diff --git a/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs b/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs index 880bfcb3831def..cc3de8dba7d995 100644 --- a/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs +++ b/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs @@ -1190,6 +1190,25 @@ public static void Equal(Half expected, Half actual, Half variance, string? bann throw EqualException.ForMismatchedValues(ToStringPadded(expected), ToStringPadded(actual), banner); } } + + /// Verifies that two values are equal, within the . + /// The expected value + /// The value to be compared against + /// The total variance allowed between the expected and actual results. + /// The banner to show; if null, then the standard + /// banner of "Values differ" will be used + /// Thrown when the values are not equal + public static void Equal(NFloat expected, NFloat actual, NFloat variance, string? banner = null) + { + if (NFloat.Size == 4) + { + Equal((float)expected, (float)actual, (float)variance, banner); + } + else + { + Equal((double)expected, (double)actual, (double)variance, banner); + } + } #endif /// Verifies that two values's binary representations are identical. @@ -1257,6 +1276,22 @@ public static void Equal(Half expected, Half actual) throw EqualException.ForMismatchedValues(ToStringPadded(expected), ToStringPadded(actual)); } + + /// Verifies that two values's binary representations are identical. + /// The expected value + /// The value to be compared against + /// Thrown when the representations are not identical + public static void Equal(NFloat expected, NFloat actual) + { + if (NFloat.Size == 4) + { + Equal((float)expected, (float)actual); + } + else + { + Equal((double)expected, (double)actual); + } + } #endif } } diff --git a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs index 27a32a3b27a0ae..d3ebdecefa92d9 100644 --- a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs +++ b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Runtime.InteropServices; using Xunit; +using Xunit.Sdk; namespace System.Numerics.Tensors.Tests { @@ -34,28 +35,74 @@ private static class DefaultTolerance where T : unmanaged, INumber public static readonly T Value = DetermineTolerance(DefaultDoubleTolerance, DefaultFloatTolerance, Half.CreateTruncating(DefaultHalfTolerance)) ?? T.CreateTruncating(0); } - public static bool IsEqualWithTolerance(T expected, T actual, T? tolerance = null) where T : unmanaged, INumber + public static void AssertEqualWithTolerance(T expected, T actual, T? tolerance = null, string? banner = null) where T : unmanaged, INumber { - if (T.IsNaN(expected) != T.IsNaN(actual)) + T actualTolerance = tolerance ?? DefaultTolerance.Value; + try { - return false; + T scaledTolerance = checked(T.Max(T.Abs(expected), T.Abs(actual)) * actualTolerance); + if (T.IsFinite(scaledTolerance)) + { + actualTolerance = T.Max(scaledTolerance, actualTolerance); + } } + catch (OverflowException) { } // Multiplication and T.Abs can throw for integers, just keep the original tolerance in that case. - tolerance = tolerance ?? DefaultTolerance.Value; - T diff = T.Abs(expected - actual); - return !(diff > tolerance && diff > T.Max(T.Abs(expected), T.Abs(actual)) * tolerance); + // Delegate to AssertExtensions.Equal for special value comparisons (NaN, +-inf, +-0) + if (typeof(T) == typeof(double)) + { + AssertExtensions.Equal((double)(object)expected, (double)(object)actual, (double)(object)actualTolerance, banner); + } + else if (typeof(T) == typeof(float)) + { + AssertExtensions.Equal((float)(object)expected, (float)(object)actual, (float)(object)actualTolerance, banner); + } + else if (typeof(T) == typeof(Half)) + { + AssertExtensions.Equal((Half)(object)expected, (Half)(object)actual, (Half)(object)actualTolerance, banner); + } + else if (typeof(T) == typeof(NFloat)) + { + AssertExtensions.Equal((NFloat)(object)expected, (NFloat)(object)actual, (NFloat)(object)actualTolerance, banner); + } + else if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(byte) || + typeof(T) == typeof(short) || typeof(T) == typeof(ushort) || typeof(T) == typeof(char) || + typeof(T) == typeof(int) || typeof(T) == typeof(uint) || + typeof(T) == typeof(long) || typeof(T) == typeof(ulong) || + typeof(T) == typeof(nint) || typeof(T) == typeof(nuint) || + typeof(T) == typeof(Int128) || typeof(T) == typeof(UInt128)) + { + T delta; + try + { + delta = T.Abs(checked(expected - actual)); + } + catch (OverflowException) + { + // Subtraction and T.Abs can throw for integers, in that case the mismatch is large enough to fail assertion + throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString(), banner); + } + if (delta > actualTolerance) + { + throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString(), banner); + } + } + else + { + throw new NotImplementedException($"Type not supported for {nameof(AssertEqualWithTolerance)}: {typeof(T).Name}"); + } } #else - public static bool IsEqualWithTolerance(float expected, float actual, float? tolerance = null) + public static void AssertEqualWithTolerance(float expected, float actual, float? tolerance = null, string? banner = null) { - if (float.IsNaN(expected) != float.IsNaN(actual)) + float actualTolerance = tolerance ?? DefaultFloatTolerance; + float scaledTolerance = MathF.Max(MathF.Abs(expected), MathF.Abs(actual)) * (tolerance ?? DefaultFloatTolerance); + if (!float.IsNaN(scaledTolerance) && !float.IsInfinity(scaledTolerance)) { - return false; + actualTolerance = MathF.Max(actualTolerance, scaledTolerance); } - tolerance ??= DefaultFloatTolerance; - float diff = MathF.Abs(expected - actual); - return !(diff > tolerance && diff > MathF.Max(MathF.Abs(expected), MathF.Abs(actual)) * tolerance); + AssertExtensions.Equal(expected, actual, actualTolerance, banner); } #endif @@ -82,13 +129,13 @@ public static bool IsEqualWithTolerance(float expected, float actual, float? tol } else if (typeof(T) == typeof(NFloat)) { - if (IntPtr.Size == 8 && doubleTolerance != null) + if (NFloat.Size == 8 && doubleTolerance != null) { return (T?)(object)(NFloat)doubleTolerance; } - else if (IntPtr.Size == 4 && floatTolerance != null) + else if (NFloat.Size == 4 && floatTolerance != null) { - return (T?)(object)(NFloat)doubleTolerance; + return (T?)(object)(NFloat)floatTolerance; } } #endif diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs index 2cf563faa56bb4..8d40454f9b9ade 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs @@ -96,6 +96,7 @@ private static void ConvertTruncatingImpl() where TFrom : unmanaged, INumber where TTo : unmanaged, INumber { + string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}"; AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertTruncating(new TFrom[3], new TTo[2])); Random rand = new(42); @@ -116,10 +117,7 @@ private static void ConvertTruncatingImpl() for (int i = 0; i < tensorLength; i++) { - if (!Helpers.IsEqualWithTolerance(TTo.CreateTruncating(source.Span[i]), destination.Span[i])) - { - throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateTruncating(source.Span[i])}."); - } + Helpers.AssertEqualWithTolerance(TTo.CreateTruncating(source.Span[i]), destination.Span[i], banner: banner); } } } @@ -128,6 +126,7 @@ private static void ConvertSaturatingImpl() where TFrom : unmanaged, INumber where TTo : unmanaged, INumber { + string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}"; AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertSaturating(new TFrom[3], new TTo[2])); Random rand = new(42); @@ -148,10 +147,7 @@ private static void ConvertSaturatingImpl() for (int i = 0; i < tensorLength; i++) { - if (!Helpers.IsEqualWithTolerance(TTo.CreateSaturating(source.Span[i]), destination.Span[i])) - { - throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateSaturating(source.Span[i])}."); - } + Helpers.AssertEqualWithTolerance(TTo.CreateSaturating(source.Span[i]), destination.Span[i], banner: banner); } } } @@ -160,6 +156,7 @@ private static void ConvertCheckedImpl() where TFrom : unmanaged, INumber where TTo : unmanaged, INumber { + string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}"; AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertChecked(new TFrom[3], new TTo[2])); foreach (int tensorLength in Helpers.TensorLengthsIncluding0) @@ -180,10 +177,7 @@ private static void ConvertCheckedImpl() for (int i = 0; i < tensorLength; i++) { - if (!Helpers.IsEqualWithTolerance(TTo.CreateChecked(source.Span[i]), destination.Span[i])) - { - throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateChecked(source.Span[i])}."); - } + Helpers.AssertEqualWithTolerance(TTo.CreateChecked(source.Span[i]), destination.Span[i], banner: banner); } } } @@ -192,6 +186,8 @@ private static void ConvertCheckedImpl(TFrom valid, TFrom invalid) where TFrom : unmanaged, INumber where TTo : unmanaged, INumber { + string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}"; + foreach (int tensorLength in Helpers.TensorLengths) { using BoundedMemory source = BoundedMemory.Allocate(tensorLength); @@ -202,7 +198,7 @@ private static void ConvertCheckedImpl(TFrom valid, TFrom invalid) TensorPrimitives.ConvertChecked(source.Span, destination.Span); foreach (TTo result in destination.Span) { - Assert.True(Helpers.IsEqualWithTolerance(TTo.CreateChecked(valid), result)); + Helpers.AssertEqualWithTolerance(TTo.CreateChecked(valid), result, banner: banner); } // Test with at least one invalid @@ -258,6 +254,7 @@ private static void ConvertToIntegerImpl() where TFrom : unmanaged, IFloatingPoint where TTo : unmanaged, IBinaryInteger { + string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}"; AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToInteger(new TFrom[3], new TTo[2])); Random rand = new(42); @@ -277,10 +274,7 @@ private static void ConvertToIntegerImpl() for (int i = 0; i < tensorLength; i++) { TTo expected = TFrom.ConvertToInteger(source.Span[i]); - if (!Helpers.IsEqualWithTolerance(expected, destination.Span[i])) - { - throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Expected: {expected}. Actual: {destination.Span[i]}."); - } + Helpers.AssertEqualWithTolerance(expected, destination.Span[i], banner: banner); } } } @@ -289,6 +283,7 @@ private static void ConvertToIntegerNativeImpl() where TFrom : unmanaged, IFloatingPoint where TTo : unmanaged, IBinaryInteger { + string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}"; AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToIntegerNative(new TFrom[3], new TTo[2])); Random rand = new(42); @@ -308,10 +303,7 @@ private static void ConvertToIntegerNativeImpl() for (int i = 0; i < tensorLength; i++) { TTo expected = TFrom.ConvertToIntegerNative(source.Span[i]); - if (!Helpers.IsEqualWithTolerance(expected, destination.Span[i])) - { - throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Expected: {expected}. Actual: {destination.Span[i]}."); - } + Helpers.AssertEqualWithTolerance(expected, destination.Span[i], banner: banner); } } } @@ -2752,10 +2744,7 @@ protected override T NextRandom() protected override void AssertEqualTolerance(T expected, T actual, T? tolerance = null) { - if (!Helpers.IsEqualWithTolerance(expected, actual, tolerance)) - { - throw EqualException.ForMismatchedValues($"{expected}", $"{actual}"); - } + Helpers.AssertEqualWithTolerance(expected, actual, tolerance); } protected override T Cosh(T x) => throw new NotSupportedException(); diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs index 707e27d2fae419..0bef82de4ce557 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs @@ -106,10 +106,7 @@ protected override float MinMagnitude(float x, float y) protected override void AssertEqualTolerance(float expected, float actual, float? tolerance = null) { - if (!Helpers.IsEqualWithTolerance(expected, actual, tolerance)) - { - throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString()); - } + Helpers.AssertEqualWithTolerance(expected, actual, tolerance); } protected override IEnumerable GetSpecialValues()