1+ use half:: { bf16, f16} ;
12use serde:: {
23 ser:: {
34 Error as _, Impossible , SerializeMap , SerializeSeq , SerializeStruct ,
@@ -7,7 +8,53 @@ use serde::{
78} ;
89use serde_json:: error:: Error ;
910
10- use crate :: { array:: TryCollect , DestructuredRef , IArray , INumber , IObject , IString , IValue } ;
11+ use crate :: {
12+ array:: { ArraySliceRef , TryCollect } ,
13+ DestructuredRef , IArray , INumber , IObject , IString , IValue ,
14+ } ;
15+
16+ /// Rounds an f64 to the given number of significant decimal digits.
17+ #[ inline]
18+ fn round_to_sig_digits ( val : f64 , sig_digits : u32 ) -> f64 {
19+ // how many digits are left of the decimal point, minus one
20+ // OR: what power of 10 is this number closest to? (e.g. 100 -> 10^2, 0.001 -> 10^-3)
21+ let order_of_magnitude = val. abs ( ) . log10 ( ) . floor ( ) as i32 ;
22+ // Multiplier that shifts the desired significant digits into the integer part,
23+ // so that f64::round() can snap to the nearest integer and discard the rest.
24+ // e.g. for val=3.14, order_of_magnitude = 0, sig_digits=2: scale=10, so 3.14*10=31.4 → round→31 → 31/10=3.1
25+ let scale = 10f64 . powi ( sig_digits as i32 - 1 - order_of_magnitude) ;
26+ ( val * scale) . round ( ) / scale
27+ }
28+
29+ /// Finds an f64 value that, when formatted by ryu's f64 algorithm, produces
30+ /// the shortest decimal string that still round-trips through the target
31+ /// half-precision type (f16 or bf16).
32+ ///
33+ /// ryu only supports f32/f64, and serde has no `serialize_f16`. Since f16/bf16
34+ /// have far fewer distinct values than f32, there exist shorter representations
35+ /// that uniquely identify the half value. For example, f16(0.3) = 0.300048828125,
36+ /// and "0.3" parsed as f16 gives back the same bits — so "0.3" is valid.
37+ ///
38+ /// The approach: try rounding to increasing significant digits until the
39+ /// rounded value round-trips through the type. Then return that f64
40+ /// value, so that `serialize_f64` (via ryu) reproduces it.
41+ fn find_shortest_roundtrip_f64 ( f64_val : f64 , roundtrips : impl Fn ( f64 ) -> bool ) -> f64 {
42+ if !f64_val. is_finite ( ) || f64_val. fract ( ) == 0.0 {
43+ return f64_val;
44+ }
45+ // With our usage(F16/BF16), the loop will need only ~4 iterations, since max significant digits needed is ~4
46+ // Example: f16(3.14159) stores 3.140625
47+ // sig_digits=1 → 3.0 → f16(3.0)=3.0 ≠ 3.140625 ❌
48+ // sig_digits=2 → 3.1 → f16(3.1)=3.099.. ≠ 3.140625 ❌
49+ // sig_digits=3 → 3.14 → f16(3.14)=3.140625 ✅ → returns 3.14
50+ for sig_digits in 1 ..=5u32 {
51+ let rounded = round_to_sig_digits ( f64_val, sig_digits) ;
52+ if roundtrips ( rounded) {
53+ return rounded;
54+ }
55+ }
56+ f64_val
57+ }
1158
1259impl Serialize for IValue {
1360 #[ inline]
@@ -55,11 +102,50 @@ impl Serialize for IArray {
55102 where
56103 S : Serializer ,
57104 {
58- let mut s = serializer. serialize_seq ( Some ( self . len ( ) ) ) ?;
59- for v in self {
60- s. serialize_element ( & v) ?;
105+ match self . as_slice ( ) {
106+ // Serialize typed float arrays with the shortest representation that
107+ // round-trips through the stored precision. Without this, all floats
108+ // would be promoted to f64 via INumber, and ryu's f64 algorithm would
109+ // emit unnecessarily long strings (e.g. "0.3" stored as f32 would
110+ // serialize as "0.30000001192092896" instead of "0.3").
111+ //
112+ // F32: serialize directly as f32 so ryu uses its f32 algorithm.
113+ // F16/BF16: ryu has no f16 mode and serde has no serialize_f16, so we
114+ // find the shortest decimal that round-trips through the half type and
115+ // pass the corresponding f64 value to serialize_f64.
116+ ArraySliceRef :: F32 ( slice) => {
117+ let mut s = serializer. serialize_seq ( Some ( slice. len ( ) ) ) ?;
118+ for & v in slice {
119+ s. serialize_element ( & v) ?;
120+ }
121+ s. end ( )
122+ }
123+ ArraySliceRef :: F16 ( slice) => {
124+ let mut s = serializer. serialize_seq ( Some ( slice. len ( ) ) ) ?;
125+ for & v in slice {
126+ let f64_val = f64:: from ( v) ;
127+ let shortest = find_shortest_roundtrip_f64 ( f64_val, |p| f16:: from_f64 ( p) == v) ;
128+ s. serialize_element ( & shortest) ?;
129+ }
130+ s. end ( )
131+ }
132+ ArraySliceRef :: BF16 ( slice) => {
133+ let mut s = serializer. serialize_seq ( Some ( slice. len ( ) ) ) ?;
134+ for & v in slice {
135+ let f64_val = f64:: from ( v) ;
136+ let shortest = find_shortest_roundtrip_f64 ( f64_val, |p| bf16:: from_f64 ( p) == v) ;
137+ s. serialize_element ( & shortest) ?;
138+ }
139+ s. end ( )
140+ }
141+ _ => {
142+ let mut s = serializer. serialize_seq ( Some ( self . len ( ) ) ) ?;
143+ for v in self {
144+ s. serialize_element ( & v) ?;
145+ }
146+ s. end ( )
147+ }
61148 }
62- s. end ( )
63149 }
64150}
65151
@@ -635,3 +721,202 @@ where
635721{
636722 value. serialize ( ValueSerializer )
637723}
724+
725+ #[ cfg( test) ]
726+ mod tests {
727+ use crate :: array:: { ArraySliceRef , FloatType } ;
728+ use crate :: { FPHAConfig , IArray , IValue , IValueDeserSeed } ;
729+
730+ #[ test]
731+ fn test_f32_array_serialization_preserves_short_representation ( ) {
732+ let mut arr = IArray :: new ( ) ;
733+ arr. push_with_fp_type ( IValue :: from ( 0.3 ) , FloatType :: F32 )
734+ . unwrap ( ) ;
735+ assert ! ( matches!( arr. as_slice( ) , ArraySliceRef :: F32 ( _) ) ) ;
736+
737+ let json = serde_json:: to_string ( & arr) . unwrap ( ) ;
738+ assert_eq ! (
739+ json, "[0.3]" ,
740+ "F32 array should serialize 0.3 as '0.3', not with extra f64 precision digits"
741+ ) ;
742+ }
743+
744+ #[ test]
745+ fn test_f64_array_serialization_preserves_short_representation ( ) {
746+ let mut arr = IArray :: new ( ) ;
747+ arr. push_with_fp_type ( IValue :: from ( 0.3 ) , FloatType :: F64 )
748+ . unwrap ( ) ;
749+ assert ! ( matches!( arr. as_slice( ) , ArraySliceRef :: F64 ( _) ) ) ;
750+
751+ let json = serde_json:: to_string ( & arr) . unwrap ( ) ;
752+ assert_eq ! ( json, "[0.3]" ) ;
753+ }
754+
755+ #[ test]
756+ fn test_f16_array_serialization_preserves_short_representation ( ) {
757+ let mut arr = IArray :: new ( ) ;
758+ arr. push_with_fp_type ( IValue :: from ( 1.5 ) , FloatType :: F16 )
759+ . unwrap ( ) ;
760+ assert_eq ! ( serde_json:: to_string( & arr) . unwrap( ) , "[1.5]" ) ;
761+
762+ let mut arr2 = IArray :: new ( ) ;
763+ arr2. push_with_fp_type ( IValue :: from ( 0.3 ) , FloatType :: F16 )
764+ . unwrap ( ) ;
765+ assert_eq ! (
766+ serde_json:: to_string( & arr2) . unwrap( ) ,
767+ "[0.3]" ,
768+ "F16 array should serialize 0.3 as '0.3', not '0.30004883' or '0.300048828125'"
769+ ) ;
770+ }
771+
772+ #[ test]
773+ fn test_bf16_array_serialization_preserves_short_representation ( ) {
774+ let mut arr = IArray :: new ( ) ;
775+ arr. push_with_fp_type ( IValue :: from ( 1.5 ) , FloatType :: BF16 )
776+ . unwrap ( ) ;
777+ assert_eq ! ( serde_json:: to_string( & arr) . unwrap( ) , "[1.5]" ) ;
778+
779+ let mut arr2 = IArray :: new ( ) ;
780+ arr2. push_with_fp_type ( IValue :: from ( 0.3 ) , FloatType :: BF16 )
781+ . unwrap ( ) ;
782+ assert_eq ! (
783+ serde_json:: to_string( & arr2) . unwrap( ) ,
784+ "[0.3]" ,
785+ "BF16 array should serialize 0.3 as '0.3'"
786+ ) ;
787+ }
788+
789+ #[ test]
790+ fn test_typed_float_array_serialization_roundtrip ( ) {
791+ let input = "[0.3,0.1,0.7,1.0,2.5,100.0]" ;
792+ let fp_types = [
793+ FloatType :: F16 ,
794+ FloatType :: BF16 ,
795+ FloatType :: F32 ,
796+ FloatType :: F64 ,
797+ ] ;
798+
799+ let jsons: Vec < String > = fp_types
800+ . iter ( )
801+ . map ( |& fp_type| {
802+ let seed = IValueDeserSeed :: new ( Some ( FPHAConfig :: new_with_type ( fp_type) ) ) ;
803+ let mut de = serde_json:: Deserializer :: from_str ( input) ;
804+ let arr = serde:: de:: DeserializeSeed :: deserialize ( seed, & mut de)
805+ . unwrap ( )
806+ . into_array ( )
807+ . unwrap ( ) ;
808+ let json_out = serde_json:: to_string ( & arr) . unwrap ( ) ;
809+ assert_eq ! (
810+ json_out, input,
811+ "{fp_type} round-trip should preserve the original JSON string"
812+ ) ;
813+ json_out
814+ } )
815+ . collect ( ) ;
816+
817+ for pair in jsons. windows ( 2 ) {
818+ assert_eq ! (
819+ pair[ 0 ] , pair[ 1 ] ,
820+ "all float types should produce identical JSON"
821+ ) ;
822+ }
823+ }
824+
825+ #[ test]
826+ fn test_f16_precision_loss_produces_different_but_short_representation ( ) {
827+ // Values with more significant digits than f16 can represent (~3.3 digits).
828+ // The stored f16 value differs from the original, so the serialized string
829+ // must differ too — but it should still be the shortest string that
830+ // round-trips through f16.
831+ let cases: & [ ( & str , & str ) ] = & [
832+ ( "3.14159" , "3.14" ) , // pi truncated: f16 stores 3.140625
833+ ( "42.42" , "42.4" ) , // f16 stores 42.40625
834+ ( "12.345" , "12.34" ) , // f16 stores 12.34375
835+ ( "0.5678" , "0.568" ) , // f16 stores 0.56787109375
836+ ] ;
837+
838+ for & ( input, expected_f16) in cases {
839+ let json_input = format ! ( "[{input}]" ) ;
840+
841+ let f16_arr: IArray = {
842+ let seed = IValueDeserSeed :: new ( Some ( FPHAConfig :: new_with_type ( FloatType :: F16 ) ) ) ;
843+ let mut de = serde_json:: Deserializer :: from_str ( & json_input) ;
844+ serde:: de:: DeserializeSeed :: deserialize ( seed, & mut de)
845+ . unwrap ( )
846+ . into_array ( )
847+ . unwrap ( )
848+ } ;
849+ let f16_json = serde_json:: to_string ( & f16_arr) . unwrap ( ) ;
850+ assert_eq ! (
851+ f16_json,
852+ format!( "[{expected_f16}]" ) ,
853+ "F16 of {input}: should serialize as shortest f16 representation"
854+ ) ;
855+
856+ // Same values through F32 should preserve the original (enough precision)
857+ let f32_arr: IArray = {
858+ let seed = IValueDeserSeed :: new ( Some ( FPHAConfig :: new_with_type ( FloatType :: F32 ) ) ) ;
859+ let mut de = serde_json:: Deserializer :: from_str ( & json_input) ;
860+ serde:: de:: DeserializeSeed :: deserialize ( seed, & mut de)
861+ . unwrap ( )
862+ . into_array ( )
863+ . unwrap ( )
864+ } ;
865+ let f32_json = serde_json:: to_string ( & f32_arr) . unwrap ( ) ;
866+ assert_eq ! (
867+ f32_json, json_input,
868+ "F32 of {input}: should preserve the original representation"
869+ ) ;
870+ }
871+ }
872+
873+ #[ test]
874+ fn test_negative_float_array_serialization ( ) {
875+ let input = "[-0.3,-0.1,-1.0,-2.5,-100.0]" ;
876+ let fp_types = [
877+ FloatType :: F16 ,
878+ FloatType :: BF16 ,
879+ FloatType :: F32 ,
880+ FloatType :: F64 ,
881+ ] ;
882+
883+ for & fp_type in & fp_types {
884+ let seed = IValueDeserSeed :: new ( Some ( FPHAConfig :: new_with_type ( fp_type) ) ) ;
885+ let mut de = serde_json:: Deserializer :: from_str ( input) ;
886+ let arr = serde:: de:: DeserializeSeed :: deserialize ( seed, & mut de)
887+ . unwrap ( )
888+ . into_array ( )
889+ . unwrap ( ) ;
890+ let json_out = serde_json:: to_string ( & arr) . unwrap ( ) ;
891+ assert_eq ! (
892+ json_out, input,
893+ "{fp_type} negative round-trip should preserve the original JSON string"
894+ ) ;
895+ }
896+ }
897+
898+ #[ test]
899+ fn test_negative_f16_precision_loss_produces_short_representation ( ) {
900+ let cases: & [ ( & str , & str ) ] = & [
901+ ( "-3.14159" , "-3.14" ) ,
902+ ( "-42.42" , "-42.4" ) ,
903+ ( "-0.5678" , "-0.568" ) ,
904+ ] ;
905+
906+ for & ( input, expected_f16) in cases {
907+ let json_input = format ! ( "[{input}]" ) ;
908+ let seed = IValueDeserSeed :: new ( Some ( FPHAConfig :: new_with_type ( FloatType :: F16 ) ) ) ;
909+ let mut de = serde_json:: Deserializer :: from_str ( & json_input) ;
910+ let arr = serde:: de:: DeserializeSeed :: deserialize ( seed, & mut de)
911+ . unwrap ( )
912+ . into_array ( )
913+ . unwrap ( ) ;
914+ let json_out = serde_json:: to_string ( & arr) . unwrap ( ) ;
915+ assert_eq ! (
916+ json_out,
917+ format!( "[{expected_f16}]" ) ,
918+ "F16 of {input}: negative should serialize as shortest f16 representation"
919+ ) ;
920+ }
921+ }
922+ }
0 commit comments