Skip to content

Commit 3dcd744

Browse files
authored
MOD-14559 Fix float serialization precision for typed FPHA arrays (#19)
* MOD-14559 Fix float serialization precision for typed FPHA arrays (F32/F16/BF16)
1 parent 446885b commit 3dcd744

1 file changed

Lines changed: 290 additions & 5 deletions

File tree

src/ser.rs

Lines changed: 290 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use half::{bf16, f16};
12
use serde::{
23
ser::{
34
Error as _, Impossible, SerializeMap, SerializeSeq, SerializeStruct,
@@ -7,7 +8,53 @@ use serde::{
78
};
89
use serde_json::error::Error;
910

10-
use crate::{array::TryCollect, DestructuredRef, IArray, INumber, IObject, IString, IValue};
11+
use crate::{
12+
array::{ArraySliceRef, TryCollect},
13+
DestructuredRef, IArray, INumber, IObject, IString, IValue,
14+
};
15+
16+
/// Rounds an f64 to the given number of significant decimal digits.
17+
#[inline]
18+
fn round_to_sig_digits(val: f64, sig_digits: u32) -> f64 {
19+
// how many digits are left of the decimal point, minus one
20+
// OR: what power of 10 is this number closest to? (e.g. 100 -> 10^2, 0.001 -> 10^-3)
21+
let order_of_magnitude = val.abs().log10().floor() as i32;
22+
// Multiplier that shifts the desired significant digits into the integer part,
23+
// so that f64::round() can snap to the nearest integer and discard the rest.
24+
// e.g. for val=3.14, order_of_magnitude = 0, sig_digits=2: scale=10, so 3.14*10=31.4 → round→31 → 31/10=3.1
25+
let scale = 10f64.powi(sig_digits as i32 - 1 - order_of_magnitude);
26+
(val * scale).round() / scale
27+
}
28+
29+
/// Finds an f64 value that, when formatted by ryu's f64 algorithm, produces
30+
/// the shortest decimal string that still round-trips through the target
31+
/// half-precision type (f16 or bf16).
32+
///
33+
/// ryu only supports f32/f64, and serde has no `serialize_f16`. Since f16/bf16
34+
/// have far fewer distinct values than f32, there exist shorter representations
35+
/// that uniquely identify the half value. For example, f16(0.3) = 0.300048828125,
36+
/// and "0.3" parsed as f16 gives back the same bits — so "0.3" is valid.
37+
///
38+
/// The approach: try rounding to increasing significant digits until the
39+
/// rounded value round-trips through the type. Then return that f64
40+
/// value, so that `serialize_f64` (via ryu) reproduces it.
41+
fn find_shortest_roundtrip_f64(f64_val: f64, roundtrips: impl Fn(f64) -> bool) -> f64 {
42+
if !f64_val.is_finite() || f64_val.fract() == 0.0 {
43+
return f64_val;
44+
}
45+
// With our usage(F16/BF16), the loop will need only ~4 iterations, since max significant digits needed is ~4
46+
// Example: f16(3.14159) stores 3.140625
47+
// sig_digits=1 → 3.0 → f16(3.0)=3.0 ≠ 3.140625 ❌
48+
// sig_digits=2 → 3.1 → f16(3.1)=3.099.. ≠ 3.140625 ❌
49+
// sig_digits=3 → 3.14 → f16(3.14)=3.140625 ✅ → returns 3.14
50+
for sig_digits in 1..=5u32 {
51+
let rounded = round_to_sig_digits(f64_val, sig_digits);
52+
if roundtrips(rounded) {
53+
return rounded;
54+
}
55+
}
56+
f64_val
57+
}
1158

1259
impl Serialize for IValue {
1360
#[inline]
@@ -55,11 +102,50 @@ impl Serialize for IArray {
55102
where
56103
S: Serializer,
57104
{
58-
let mut s = serializer.serialize_seq(Some(self.len()))?;
59-
for v in self {
60-
s.serialize_element(&v)?;
105+
match self.as_slice() {
106+
// Serialize typed float arrays with the shortest representation that
107+
// round-trips through the stored precision. Without this, all floats
108+
// would be promoted to f64 via INumber, and ryu's f64 algorithm would
109+
// emit unnecessarily long strings (e.g. "0.3" stored as f32 would
110+
// serialize as "0.30000001192092896" instead of "0.3").
111+
//
112+
// F32: serialize directly as f32 so ryu uses its f32 algorithm.
113+
// F16/BF16: ryu has no f16 mode and serde has no serialize_f16, so we
114+
// find the shortest decimal that round-trips through the half type and
115+
// pass the corresponding f64 value to serialize_f64.
116+
ArraySliceRef::F32(slice) => {
117+
let mut s = serializer.serialize_seq(Some(slice.len()))?;
118+
for &v in slice {
119+
s.serialize_element(&v)?;
120+
}
121+
s.end()
122+
}
123+
ArraySliceRef::F16(slice) => {
124+
let mut s = serializer.serialize_seq(Some(slice.len()))?;
125+
for &v in slice {
126+
let f64_val = f64::from(v);
127+
let shortest = find_shortest_roundtrip_f64(f64_val, |p| f16::from_f64(p) == v);
128+
s.serialize_element(&shortest)?;
129+
}
130+
s.end()
131+
}
132+
ArraySliceRef::BF16(slice) => {
133+
let mut s = serializer.serialize_seq(Some(slice.len()))?;
134+
for &v in slice {
135+
let f64_val = f64::from(v);
136+
let shortest = find_shortest_roundtrip_f64(f64_val, |p| bf16::from_f64(p) == v);
137+
s.serialize_element(&shortest)?;
138+
}
139+
s.end()
140+
}
141+
_ => {
142+
let mut s = serializer.serialize_seq(Some(self.len()))?;
143+
for v in self {
144+
s.serialize_element(&v)?;
145+
}
146+
s.end()
147+
}
61148
}
62-
s.end()
63149
}
64150
}
65151

@@ -635,3 +721,202 @@ where
635721
{
636722
value.serialize(ValueSerializer)
637723
}
724+
725+
#[cfg(test)]
726+
mod tests {
727+
use crate::array::{ArraySliceRef, FloatType};
728+
use crate::{FPHAConfig, IArray, IValue, IValueDeserSeed};
729+
730+
#[test]
731+
fn test_f32_array_serialization_preserves_short_representation() {
732+
let mut arr = IArray::new();
733+
arr.push_with_fp_type(IValue::from(0.3), FloatType::F32)
734+
.unwrap();
735+
assert!(matches!(arr.as_slice(), ArraySliceRef::F32(_)));
736+
737+
let json = serde_json::to_string(&arr).unwrap();
738+
assert_eq!(
739+
json, "[0.3]",
740+
"F32 array should serialize 0.3 as '0.3', not with extra f64 precision digits"
741+
);
742+
}
743+
744+
#[test]
745+
fn test_f64_array_serialization_preserves_short_representation() {
746+
let mut arr = IArray::new();
747+
arr.push_with_fp_type(IValue::from(0.3), FloatType::F64)
748+
.unwrap();
749+
assert!(matches!(arr.as_slice(), ArraySliceRef::F64(_)));
750+
751+
let json = serde_json::to_string(&arr).unwrap();
752+
assert_eq!(json, "[0.3]");
753+
}
754+
755+
#[test]
756+
fn test_f16_array_serialization_preserves_short_representation() {
757+
let mut arr = IArray::new();
758+
arr.push_with_fp_type(IValue::from(1.5), FloatType::F16)
759+
.unwrap();
760+
assert_eq!(serde_json::to_string(&arr).unwrap(), "[1.5]");
761+
762+
let mut arr2 = IArray::new();
763+
arr2.push_with_fp_type(IValue::from(0.3), FloatType::F16)
764+
.unwrap();
765+
assert_eq!(
766+
serde_json::to_string(&arr2).unwrap(),
767+
"[0.3]",
768+
"F16 array should serialize 0.3 as '0.3', not '0.30004883' or '0.300048828125'"
769+
);
770+
}
771+
772+
#[test]
773+
fn test_bf16_array_serialization_preserves_short_representation() {
774+
let mut arr = IArray::new();
775+
arr.push_with_fp_type(IValue::from(1.5), FloatType::BF16)
776+
.unwrap();
777+
assert_eq!(serde_json::to_string(&arr).unwrap(), "[1.5]");
778+
779+
let mut arr2 = IArray::new();
780+
arr2.push_with_fp_type(IValue::from(0.3), FloatType::BF16)
781+
.unwrap();
782+
assert_eq!(
783+
serde_json::to_string(&arr2).unwrap(),
784+
"[0.3]",
785+
"BF16 array should serialize 0.3 as '0.3'"
786+
);
787+
}
788+
789+
#[test]
790+
fn test_typed_float_array_serialization_roundtrip() {
791+
let input = "[0.3,0.1,0.7,1.0,2.5,100.0]";
792+
let fp_types = [
793+
FloatType::F16,
794+
FloatType::BF16,
795+
FloatType::F32,
796+
FloatType::F64,
797+
];
798+
799+
let jsons: Vec<String> = fp_types
800+
.iter()
801+
.map(|&fp_type| {
802+
let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type)));
803+
let mut de = serde_json::Deserializer::from_str(input);
804+
let arr = serde::de::DeserializeSeed::deserialize(seed, &mut de)
805+
.unwrap()
806+
.into_array()
807+
.unwrap();
808+
let json_out = serde_json::to_string(&arr).unwrap();
809+
assert_eq!(
810+
json_out, input,
811+
"{fp_type} round-trip should preserve the original JSON string"
812+
);
813+
json_out
814+
})
815+
.collect();
816+
817+
for pair in jsons.windows(2) {
818+
assert_eq!(
819+
pair[0], pair[1],
820+
"all float types should produce identical JSON"
821+
);
822+
}
823+
}
824+
825+
#[test]
826+
fn test_f16_precision_loss_produces_different_but_short_representation() {
827+
// Values with more significant digits than f16 can represent (~3.3 digits).
828+
// The stored f16 value differs from the original, so the serialized string
829+
// must differ too — but it should still be the shortest string that
830+
// round-trips through f16.
831+
let cases: &[(&str, &str)] = &[
832+
("3.14159", "3.14"), // pi truncated: f16 stores 3.140625
833+
("42.42", "42.4"), // f16 stores 42.40625
834+
("12.345", "12.34"), // f16 stores 12.34375
835+
("0.5678", "0.568"), // f16 stores 0.56787109375
836+
];
837+
838+
for &(input, expected_f16) in cases {
839+
let json_input = format!("[{input}]");
840+
841+
let f16_arr: IArray = {
842+
let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16)));
843+
let mut de = serde_json::Deserializer::from_str(&json_input);
844+
serde::de::DeserializeSeed::deserialize(seed, &mut de)
845+
.unwrap()
846+
.into_array()
847+
.unwrap()
848+
};
849+
let f16_json = serde_json::to_string(&f16_arr).unwrap();
850+
assert_eq!(
851+
f16_json,
852+
format!("[{expected_f16}]"),
853+
"F16 of {input}: should serialize as shortest f16 representation"
854+
);
855+
856+
// Same values through F32 should preserve the original (enough precision)
857+
let f32_arr: IArray = {
858+
let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F32)));
859+
let mut de = serde_json::Deserializer::from_str(&json_input);
860+
serde::de::DeserializeSeed::deserialize(seed, &mut de)
861+
.unwrap()
862+
.into_array()
863+
.unwrap()
864+
};
865+
let f32_json = serde_json::to_string(&f32_arr).unwrap();
866+
assert_eq!(
867+
f32_json, json_input,
868+
"F32 of {input}: should preserve the original representation"
869+
);
870+
}
871+
}
872+
873+
#[test]
874+
fn test_negative_float_array_serialization() {
875+
let input = "[-0.3,-0.1,-1.0,-2.5,-100.0]";
876+
let fp_types = [
877+
FloatType::F16,
878+
FloatType::BF16,
879+
FloatType::F32,
880+
FloatType::F64,
881+
];
882+
883+
for &fp_type in &fp_types {
884+
let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(fp_type)));
885+
let mut de = serde_json::Deserializer::from_str(input);
886+
let arr = serde::de::DeserializeSeed::deserialize(seed, &mut de)
887+
.unwrap()
888+
.into_array()
889+
.unwrap();
890+
let json_out = serde_json::to_string(&arr).unwrap();
891+
assert_eq!(
892+
json_out, input,
893+
"{fp_type} negative round-trip should preserve the original JSON string"
894+
);
895+
}
896+
}
897+
898+
#[test]
899+
fn test_negative_f16_precision_loss_produces_short_representation() {
900+
let cases: &[(&str, &str)] = &[
901+
("-3.14159", "-3.14"),
902+
("-42.42", "-42.4"),
903+
("-0.5678", "-0.568"),
904+
];
905+
906+
for &(input, expected_f16) in cases {
907+
let json_input = format!("[{input}]");
908+
let seed = IValueDeserSeed::new(Some(FPHAConfig::new_with_type(FloatType::F16)));
909+
let mut de = serde_json::Deserializer::from_str(&json_input);
910+
let arr = serde::de::DeserializeSeed::deserialize(seed, &mut de)
911+
.unwrap()
912+
.into_array()
913+
.unwrap();
914+
let json_out = serde_json::to_string(&arr).unwrap();
915+
assert_eq!(
916+
json_out,
917+
format!("[{expected_f16}]"),
918+
"F16 of {input}: negative should serialize as shortest f16 representation"
919+
);
920+
}
921+
}
922+
}

0 commit comments

Comments
 (0)