Skip to content

Commit ce1927d

Browse files
authored
Merge pull request #52 from MollySophia/deserialize
Reduce deserialization memory footprint on iOS devices
2 parents 1c984b7 + d27ceec commit ce1927d

File tree

7 files changed

+30
-40
lines changed

7 files changed

+30
-40
lines changed

crates/web-rwkv-derive/src/serde/de.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,9 @@ fn deserialize_seq(
542542
.filter(|field| !field.attrs.skip_deserializing())
543543
.count();
544544
let expecting = if deserialized_count == 1 {
545-
format!("{} with 1 element", expecting)
545+
format!("{expecting} with 1 element")
546546
} else {
547-
format!("{} with {} elements", expecting, deserialized_count)
547+
format!("{expecting} with {deserialized_count} elements")
548548
};
549549
let expecting = cattrs.expecting().unwrap_or(&expecting);
550550

@@ -1126,7 +1126,7 @@ fn deserialize_adjacently_tagged_enum(
11261126
.collect();
11271127

11281128
let rust_name = params.type_name();
1129-
let expecting = format!("adjacently tagged enum {}", rust_name);
1129+
let expecting = format!("adjacently tagged enum {rust_name}");
11301130
let expecting = cattrs.expecting().unwrap_or(&expecting);
11311131
let type_name = cattrs.name().deserialize_name();
11321132
let deny_unknown_fields = cattrs.deny_unknown_fields();
@@ -2377,7 +2377,7 @@ fn deserialize_map(
23772377
}
23782378

23792379
fn field_i(i: usize) -> Ident {
2380-
Ident::new(&format!("__field{}", i), Span::call_site())
2380+
Ident::new(&format!("__field{i}"), Span::call_site())
23812381
}
23822382

23832383
/// This function wraps the expression in `#[serde(deserialize_with = "...")]`

crates/web-rwkv-derive/src/serde/internals/attr.rs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ impl Container {
556556
} else {
557557
let path = meta.path.to_token_stream().to_string().replace(' ', "");
558558
return Err(
559-
meta.error(format_args!("unknown serde container attribute `{}`", path))
559+
meta.error(format_args!("unknown serde container attribute `{path}`"))
560560
);
561561
}
562562
Ok(())
@@ -974,7 +974,7 @@ impl Variant {
974974
} else {
975975
let path = meta.path.to_token_stream().to_string().replace(' ', "");
976976
return Err(
977-
meta.error(format_args!("unknown serde variant attribute `{}`", path))
977+
meta.error(format_args!("unknown serde variant attribute `{path}`"))
978978
);
979979
}
980980
Ok(())
@@ -1129,8 +1129,7 @@ impl Field {
11291129
if let Some(lifetimes) = &borrow_attribute.lifetimes {
11301130
for lifetime in lifetimes {
11311131
if !borrowable.contains(lifetime) {
1132-
let msg =
1133-
format!("field `{}` does not have lifetime {}", ident, lifetime);
1132+
let msg = format!("field `{ident}` does not have lifetime {lifetime}");
11341133
cx.error_spanned_by(field, msg);
11351134
}
11361135
}
@@ -1232,8 +1231,7 @@ impl Field {
12321231
for lifetime in &lifetimes {
12331232
if !borrowable.contains(lifetime) {
12341233
let msg = format!(
1235-
"field `{}` does not have lifetime {}",
1236-
ident, lifetime,
1234+
"field `{ident}` does not have lifetime {lifetime}",
12371235
);
12381236
cx.error_spanned_by(field, msg);
12391237
}
@@ -1256,9 +1254,7 @@ impl Field {
12561254
flatten.set_true(&meta.path);
12571255
} else {
12581256
let path = meta.path.to_token_stream().to_string().replace(' ', "");
1259-
return Err(
1260-
meta.error(format_args!("unknown serde field attribute `{}`", path))
1261-
);
1257+
return Err(meta.error(format_args!("unknown serde field attribute `{path}`")));
12621258
}
12631259
Ok(())
12641260
}) {
@@ -1451,8 +1447,7 @@ where
14511447
}
14521448
} else {
14531449
return Err(meta.error(format_args!(
1454-
"malformed {0} attribute, expected `{0}(serialize = ..., deserialize = ...)`",
1455-
attr_name,
1450+
"malformed {attr_name} attribute, expected `{attr_name}(serialize = ..., deserialize = ...)`",
14561451
)));
14571452
}
14581453
Ok(())
@@ -1517,16 +1512,15 @@ fn get_lit_str2(
15171512
if !suffix.is_empty() {
15181513
cx.error_spanned_by(
15191514
lit,
1520-
format!("unexpected suffix `{}` on string literal", suffix),
1515+
format!("unexpected suffix `{suffix}` on string literal"),
15211516
);
15221517
}
15231518
Ok(Some(lit.clone()))
15241519
} else {
15251520
cx.error_spanned_by(
15261521
expr,
15271522
format!(
1528-
"expected serde {} attribute to be a string: `{} = \"...\"`",
1529-
attr_name, meta_item_name
1523+
"expected serde {attr_name} attribute to be a string: `{meta_item_name} = \"...\"`"
15301524
),
15311525
);
15321526
Ok(None)
@@ -1637,10 +1631,7 @@ fn parse_lit_into_lifetimes(
16371631
while !input.is_empty() {
16381632
let lifetime: Lifetime = input.parse()?;
16391633
if !set.insert(lifetime.clone()) {
1640-
cx.error_spanned_by(
1641-
&string,
1642-
format!("duplicate borrowed lifetime `{}`", lifetime),
1643-
);
1634+
cx.error_spanned_by(&string, format!("duplicate borrowed lifetime `{lifetime}`"));
16441635
}
16451636
if input.is_empty() {
16461637
break;
@@ -1813,7 +1804,7 @@ fn borrowable_lifetimes(
18131804
let mut lifetimes = BTreeSet::new();
18141805
collect_lifetimes(&field.ty, &mut lifetimes);
18151806
if lifetimes.is_empty() {
1816-
let msg = format!("field `{}` has no lifetimes to borrow", name);
1807+
let msg = format!("field `{name}` has no lifetimes to borrow");
18171808
cx.error_spanned_by(field, msg);
18181809
Err(())
18191810
} else {

crates/web-rwkv-derive/src/serde/internals/check.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ fn check_default_on_tuple(cx: &Ctxt, cont: &Container) {
3838
if let Some(first) = first_default_index {
3939
cx.error_spanned_by(
4040
field.ty,
41-
format!("field must have #[serde(default)] because previous field {} has #[serde(default)]", first),
41+
format!("field must have #[serde(default)] because previous field {first} has #[serde(default)]"),
4242
);
4343
}
4444
continue;
@@ -311,7 +311,7 @@ fn check_internal_tag_field_name_conflict(cx: &Ctxt, cont: &Container) {
311311
let diagnose_conflict = || {
312312
cx.error_spanned_by(
313313
cont.original,
314-
format!("variant field name `{}` conflicts with internal tag", tag),
314+
format!("variant field name `{tag}` conflicts with internal tag"),
315315
);
316316
};
317317

@@ -358,10 +358,7 @@ fn check_adjacent_tag_conflict(cx: &Ctxt, cont: &Container) {
358358
if type_tag == content_tag {
359359
cx.error_spanned_by(
360360
cont.original,
361-
format!(
362-
"enum tags `{}` for type and content conflict with each other",
363-
type_tag
364-
),
361+
format!("enum tags `{type_tag}` for type and content conflict with each other"),
365362
);
366363
}
367364
}
@@ -447,7 +444,7 @@ fn check_transparent(cx: &Ctxt, cont: &mut Container, derive: Derive) {
447444

448445
fn member_message(member: &Member) -> String {
449446
match member {
450-
Member::Named(ident) => format!("`{}`", ident),
447+
Member::Named(ident) => format!("`{ident}`"),
451448
Member::Unnamed(i) => format!("#{}", i.index),
452449
}
453450
}

src/context.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ impl ContextBuilder {
155155
let data = read_back_buffer(&device, &buffer);
156156
let _ = sender.send(data);
157157
}
158-
log::info!("context dropped: {}", id);
158+
log::info!("context dropped: {id}");
159159
});
160160
}
161161

src/runtime/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ where
100100
let handle = tokio::spawn(Self::run(bundle.into(), receiver));
101101
tokio::spawn(async move {
102102
if let Err(err) = handle.await {
103-
log::error!("{}", err);
103+
log::error!("{err}");
104104
}
105105
});
106106
Self(sender)

src/tensor/ops.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,25 +168,25 @@ impl std::fmt::Display for Activation {
168168
impl Macros {
169169
/// Define a `u32` macro `NF4_BLOCK_SIZE`.
170170
pub fn nf4(mut self, block_size: u32) -> Self {
171-
self.insert("NF4_BLOCK_SIZE".into(), format!("{}u", block_size));
171+
self.insert("NF4_BLOCK_SIZE".into(), format!("{block_size}u"));
172172
self
173173
}
174174

175175
/// Define a `u32` macro `NF4_BLOCK_SIZE`.
176176
pub fn int8(mut self, block_size: u32) -> Self {
177-
self.insert("INT8_BLOCK_SIZE".into(), format!("{}u", block_size));
177+
self.insert("INT8_BLOCK_SIZE".into(), format!("{block_size}u"));
178178
self
179179
}
180180

181181
/// Define a `f32` macro with a given name.
182182
pub fn f32(mut self, name: impl Into<String>, value: f32) -> Self {
183-
self.insert(name.into(), format!("{}", value));
183+
self.insert(name.into(), format!("{value}"));
184184
self
185185
}
186186

187187
/// Define a `usize` macro with a given name.
188188
pub fn u32(mut self, name: impl Into<String>, value: u32) -> Self {
189-
self.insert(name.into(), format!("{}u", value));
189+
self.insert(name.into(), format!("{value}u"));
190190
self
191191
}
192192

@@ -254,8 +254,8 @@ fn custom_tanh(x: vec4<f32>) -> vec4<f32> {
254254
/// Define a macro with custom display name and prefix.
255255
pub fn custom(mut self, value: impl std::fmt::Display, prefix: Option<&'_ str>) -> Self {
256256
match prefix {
257-
None => self.insert(format!("{}", value), Default::default()),
258-
Some(prefix) => self.insert(format!("{}_{}", prefix, value), Default::default()),
257+
None => self.insert(format!("{value}"), Default::default()),
258+
Some(prefix) => self.insert(format!("{prefix}_{value}"), Default::default()),
259259
};
260260
self
261261
}
@@ -273,7 +273,7 @@ fn custom_tanh(x: vec4<f32>) -> vec4<f32> {
273273
pub fn subgroup(self, min: u32, max: u32) -> Self {
274274
self.u32("MIN_SUBGROUP_SIZE", min)
275275
.u32("MAX_SUBGROUP_SIZE", max)
276-
.define(format!("SUBGROUP_SIZE_{}_{}", min, max), true)
276+
.define(format!("SUBGROUP_SIZE_{min}_{max}"), true)
277277
}
278278
}
279279

src/tensor/serialization.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ impl<'de, T: Scalar + Deserialize<'de>, K: Kind> DeserializeSeed<'de>
144144
{
145145
let context = &self.context;
146146
let tensor: TensorBlob<'de> = Deserialize::deserialize(deserializer)?;
147-
TensorGpu::from_data_u8(context, tensor.shape, &tensor.data).map_err(D::Error::custom)
147+
let tensor = TensorGpu::from_data_u8(context, tensor.shape, &tensor.data);
148+
context.queue.submit(None);
149+
tensor.map_err(D::Error::custom)
148150
}
149151
}
150152

0 commit comments

Comments
 (0)