diff --git a/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs b/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs
index 880bfcb3831def..cc3de8dba7d995 100644
--- a/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs
+++ b/src/libraries/Common/tests/TestUtilities/System/AssertExtensions.cs
@@ -1190,6 +1190,25 @@ public static void Equal(Half expected, Half actual, Half variance, string? bann
throw EqualException.ForMismatchedValues(ToStringPadded(expected), ToStringPadded(actual), banner);
}
}
+
+ /// Verifies that two values are equal, within the .
+ /// The expected value
+ /// The value to be compared against
+ /// The total variance allowed between the expected and actual results.
+ /// The banner to show; if null, then the standard
+ /// banner of "Values differ" will be used
+ /// Thrown when the values are not equal
+ public static void Equal(NFloat expected, NFloat actual, NFloat variance, string? banner = null)
+ {
+ if (NFloat.Size == 4)
+ {
+ Equal((float)expected, (float)actual, (float)variance, banner);
+ }
+ else
+ {
+ Equal((double)expected, (double)actual, (double)variance, banner);
+ }
+ }
#endif
/// Verifies that two values's binary representations are identical.
@@ -1257,6 +1276,22 @@ public static void Equal(Half expected, Half actual)
throw EqualException.ForMismatchedValues(ToStringPadded(expected), ToStringPadded(actual));
}
+
+ /// Verifies that two values's binary representations are identical.
+ /// The expected value
+ /// The value to be compared against
+ /// Thrown when the representations are not identical
+ public static void Equal(NFloat expected, NFloat actual)
+ {
+ if (NFloat.Size == 4)
+ {
+ Equal((float)expected, (float)actual);
+ }
+ else
+ {
+ Equal((double)expected, (double)actual);
+ }
+ }
#endif
}
}
diff --git a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs
index 27a32a3b27a0ae..d3ebdecefa92d9 100644
--- a/src/libraries/System.Numerics.Tensors/tests/Helpers.cs
+++ b/src/libraries/System.Numerics.Tensors/tests/Helpers.cs
@@ -5,6 +5,7 @@
using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
+using Xunit.Sdk;
namespace System.Numerics.Tensors.Tests
{
@@ -34,28 +35,74 @@ private static class DefaultTolerance where T : unmanaged, INumber
public static readonly T Value = DetermineTolerance(DefaultDoubleTolerance, DefaultFloatTolerance, Half.CreateTruncating(DefaultHalfTolerance)) ?? T.CreateTruncating(0);
}
- public static bool IsEqualWithTolerance(T expected, T actual, T? tolerance = null) where T : unmanaged, INumber
+ public static void AssertEqualWithTolerance(T expected, T actual, T? tolerance = null, string? banner = null) where T : unmanaged, INumber
{
- if (T.IsNaN(expected) != T.IsNaN(actual))
+ T actualTolerance = tolerance ?? DefaultTolerance.Value;
+ try
{
- return false;
+ T scaledTolerance = checked(T.Max(T.Abs(expected), T.Abs(actual)) * actualTolerance);
+ if (T.IsFinite(scaledTolerance))
+ {
+ actualTolerance = T.Max(scaledTolerance, actualTolerance);
+ }
}
+ catch (OverflowException) { } // Multiplication and T.Abs can throw for integers, just keep the original tolerance in that case.
- tolerance = tolerance ?? DefaultTolerance.Value;
- T diff = T.Abs(expected - actual);
- return !(diff > tolerance && diff > T.Max(T.Abs(expected), T.Abs(actual)) * tolerance);
+ // Delegate to AssertExtensions.Equal for special value comparisons (NaN, +-inf, +-0)
+ if (typeof(T) == typeof(double))
+ {
+ AssertExtensions.Equal((double)(object)expected, (double)(object)actual, (double)(object)actualTolerance, banner);
+ }
+ else if (typeof(T) == typeof(float))
+ {
+ AssertExtensions.Equal((float)(object)expected, (float)(object)actual, (float)(object)actualTolerance, banner);
+ }
+ else if (typeof(T) == typeof(Half))
+ {
+ AssertExtensions.Equal((Half)(object)expected, (Half)(object)actual, (Half)(object)actualTolerance, banner);
+ }
+ else if (typeof(T) == typeof(NFloat))
+ {
+ AssertExtensions.Equal((NFloat)(object)expected, (NFloat)(object)actual, (NFloat)(object)actualTolerance, banner);
+ }
+ else if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(byte) ||
+ typeof(T) == typeof(short) || typeof(T) == typeof(ushort) || typeof(T) == typeof(char) ||
+ typeof(T) == typeof(int) || typeof(T) == typeof(uint) ||
+ typeof(T) == typeof(long) || typeof(T) == typeof(ulong) ||
+ typeof(T) == typeof(nint) || typeof(T) == typeof(nuint) ||
+ typeof(T) == typeof(Int128) || typeof(T) == typeof(UInt128))
+ {
+ T delta;
+ try
+ {
+ delta = T.Abs(checked(expected - actual));
+ }
+ catch (OverflowException)
+ {
+ // Subtraction and T.Abs can throw for integers, in that case the mismatch is large enough to fail assertion
+ throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString(), banner);
+ }
+ if (delta > actualTolerance)
+ {
+ throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString(), banner);
+ }
+ }
+ else
+ {
+ throw new NotImplementedException($"Type not supported for {nameof(AssertEqualWithTolerance)}: {typeof(T).Name}");
+ }
}
#else
- public static bool IsEqualWithTolerance(float expected, float actual, float? tolerance = null)
+ public static void AssertEqualWithTolerance(float expected, float actual, float? tolerance = null, string? banner = null)
{
- if (float.IsNaN(expected) != float.IsNaN(actual))
+ float actualTolerance = tolerance ?? DefaultFloatTolerance;
+ float scaledTolerance = MathF.Max(MathF.Abs(expected), MathF.Abs(actual)) * (tolerance ?? DefaultFloatTolerance);
+ if (!float.IsNaN(scaledTolerance) && !float.IsInfinity(scaledTolerance))
{
- return false;
+ actualTolerance = MathF.Max(actualTolerance, scaledTolerance);
}
- tolerance ??= DefaultFloatTolerance;
- float diff = MathF.Abs(expected - actual);
- return !(diff > tolerance && diff > MathF.Max(MathF.Abs(expected), MathF.Abs(actual)) * tolerance);
+ AssertExtensions.Equal(expected, actual, actualTolerance, banner);
}
#endif
@@ -82,13 +129,13 @@ public static bool IsEqualWithTolerance(float expected, float actual, float? tol
}
else if (typeof(T) == typeof(NFloat))
{
- if (IntPtr.Size == 8 && doubleTolerance != null)
+ if (NFloat.Size == 8 && doubleTolerance != null)
{
return (T?)(object)(NFloat)doubleTolerance;
}
- else if (IntPtr.Size == 4 && floatTolerance != null)
+ else if (NFloat.Size == 4 && floatTolerance != null)
{
- return (T?)(object)(NFloat)doubleTolerance;
+ return (T?)(object)(NFloat)floatTolerance;
}
}
#endif
diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs
index 2cf563faa56bb4..8d40454f9b9ade 100644
--- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs
+++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs
@@ -96,6 +96,7 @@ private static void ConvertTruncatingImpl()
where TFrom : unmanaged, INumber
where TTo : unmanaged, INumber
{
+ string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertTruncating(new TFrom[3], new TTo[2]));
Random rand = new(42);
@@ -116,10 +117,7 @@ private static void ConvertTruncatingImpl()
for (int i = 0; i < tensorLength; i++)
{
- if (!Helpers.IsEqualWithTolerance(TTo.CreateTruncating(source.Span[i]), destination.Span[i]))
- {
- throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateTruncating(source.Span[i])}.");
- }
+ Helpers.AssertEqualWithTolerance(TTo.CreateTruncating(source.Span[i]), destination.Span[i], banner: banner);
}
}
}
@@ -128,6 +126,7 @@ private static void ConvertSaturatingImpl()
where TFrom : unmanaged, INumber
where TTo : unmanaged, INumber
{
+ string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertSaturating(new TFrom[3], new TTo[2]));
Random rand = new(42);
@@ -148,10 +147,7 @@ private static void ConvertSaturatingImpl()
for (int i = 0; i < tensorLength; i++)
{
- if (!Helpers.IsEqualWithTolerance(TTo.CreateSaturating(source.Span[i]), destination.Span[i]))
- {
- throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateSaturating(source.Span[i])}.");
- }
+ Helpers.AssertEqualWithTolerance(TTo.CreateSaturating(source.Span[i]), destination.Span[i], banner: banner);
}
}
}
@@ -160,6 +156,7 @@ private static void ConvertCheckedImpl()
where TFrom : unmanaged, INumber
where TTo : unmanaged, INumber
{
+ string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertChecked(new TFrom[3], new TTo[2]));
foreach (int tensorLength in Helpers.TensorLengthsIncluding0)
@@ -180,10 +177,7 @@ private static void ConvertCheckedImpl()
for (int i = 0; i < tensorLength; i++)
{
- if (!Helpers.IsEqualWithTolerance(TTo.CreateChecked(source.Span[i]), destination.Span[i]))
- {
- throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateChecked(source.Span[i])}.");
- }
+ Helpers.AssertEqualWithTolerance(TTo.CreateChecked(source.Span[i]), destination.Span[i], banner: banner);
}
}
}
@@ -192,6 +186,8 @@ private static void ConvertCheckedImpl(TFrom valid, TFrom invalid)
where TFrom : unmanaged, INumber
where TTo : unmanaged, INumber
{
+ string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
+
foreach (int tensorLength in Helpers.TensorLengths)
{
using BoundedMemory source = BoundedMemory.Allocate(tensorLength);
@@ -202,7 +198,7 @@ private static void ConvertCheckedImpl(TFrom valid, TFrom invalid)
TensorPrimitives.ConvertChecked(source.Span, destination.Span);
foreach (TTo result in destination.Span)
{
- Assert.True(Helpers.IsEqualWithTolerance(TTo.CreateChecked(valid), result));
+ Helpers.AssertEqualWithTolerance(TTo.CreateChecked(valid), result, banner: banner);
}
// Test with at least one invalid
@@ -258,6 +254,7 @@ private static void ConvertToIntegerImpl()
where TFrom : unmanaged, IFloatingPoint
where TTo : unmanaged, IBinaryInteger
{
+ string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToInteger(new TFrom[3], new TTo[2]));
Random rand = new(42);
@@ -277,10 +274,7 @@ private static void ConvertToIntegerImpl()
for (int i = 0; i < tensorLength; i++)
{
TTo expected = TFrom.ConvertToInteger(source.Span[i]);
- if (!Helpers.IsEqualWithTolerance(expected, destination.Span[i]))
- {
- throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Expected: {expected}. Actual: {destination.Span[i]}.");
- }
+ Helpers.AssertEqualWithTolerance(expected, destination.Span[i], banner: banner);
}
}
}
@@ -289,6 +283,7 @@ private static void ConvertToIntegerNativeImpl()
where TFrom : unmanaged, IFloatingPoint
where TTo : unmanaged, IBinaryInteger
{
+ string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws("destination", () => TensorPrimitives.ConvertToIntegerNative(new TFrom[3], new TTo[2]));
Random rand = new(42);
@@ -308,10 +303,7 @@ private static void ConvertToIntegerNativeImpl()
for (int i = 0; i < tensorLength; i++)
{
TTo expected = TFrom.ConvertToIntegerNative(source.Span[i]);
- if (!Helpers.IsEqualWithTolerance(expected, destination.Span[i]))
- {
- throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Expected: {expected}. Actual: {destination.Span[i]}.");
- }
+ Helpers.AssertEqualWithTolerance(expected, destination.Span[i], banner: banner);
}
}
}
@@ -2752,10 +2744,7 @@ protected override T NextRandom()
protected override void AssertEqualTolerance(T expected, T actual, T? tolerance = null)
{
- if (!Helpers.IsEqualWithTolerance(expected, actual, tolerance))
- {
- throw EqualException.ForMismatchedValues($"{expected}", $"{actual}");
- }
+ Helpers.AssertEqualWithTolerance(expected, actual, tolerance);
}
protected override T Cosh(T x) => throw new NotSupportedException();
diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs
index 707e27d2fae419..0bef82de4ce557 100644
--- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs
+++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.NonGeneric.Single.cs
@@ -106,10 +106,7 @@ protected override float MinMagnitude(float x, float y)
protected override void AssertEqualTolerance(float expected, float actual, float? tolerance = null)
{
- if (!Helpers.IsEqualWithTolerance(expected, actual, tolerance))
- {
- throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString());
- }
+ Helpers.AssertEqualWithTolerance(expected, actual, tolerance);
}
protected override IEnumerable GetSpecialValues()