Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,25 @@ public static void Equal(Half expected, Half actual, Half variance, string? bann
throw EqualException.ForMismatchedValues(ToStringPadded(expected), ToStringPadded(actual), banner);
}
}

/// <summary>Verifies that two <see cref="NFloat"/> values are equal, within the <paramref name="variance"/>.</summary>
/// <param name="expected">The expected value</param>
/// <param name="actual">The value to be compared against</param>
/// <param name="variance">The total variance allowed between the expected and actual results.</param>
/// <param name="banner">The banner to show; if <c>null</c>, then the standard
/// banner of "Values differ" will be used</param>
/// <exception cref="EqualException">Thrown when the values are not equal</exception>
public static void Equal(NFloat expected, NFloat actual, NFloat variance, string? banner = null)
{
if (NFloat.Size == 4)
{
Equal((float)expected, (float)actual, (float)variance, banner);
}
else
{
Equal((double)expected, (double)actual, (double)variance, banner);
}
}
#endif

/// <summary>Verifies that two <see cref="double"/> values's binary representations are identical.</summary>
Expand Down Expand Up @@ -1257,6 +1276,22 @@ public static void Equal(Half expected, Half actual)

throw EqualException.ForMismatchedValues(ToStringPadded(expected), ToStringPadded(actual));
}

/// <summary>Verifies that two <see cref="NFloat"/> values's binary representations are identical.</summary>
/// <param name="expected">The expected value</param>
/// <param name="actual">The value to be compared against</param>
/// <exception cref="EqualException">Thrown when the representations are not identical</exception>
public static void Equal(NFloat expected, NFloat actual)
{
if (NFloat.Size == 4)
{
Equal((float)expected, (float)actual);
}
else
{
Equal((double)expected, (double)actual);
}
}
#endif
}
}
77 changes: 62 additions & 15 deletions src/libraries/System.Numerics.Tensors/tests/Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq;
using System.Runtime.InteropServices;
using Xunit;
using Xunit.Sdk;

namespace System.Numerics.Tensors.Tests
{
Expand Down Expand Up @@ -34,28 +35,74 @@ private static class DefaultTolerance<T> where T : unmanaged, INumber<T>
public static readonly T Value = DetermineTolerance<T>(DefaultDoubleTolerance, DefaultFloatTolerance, Half.CreateTruncating(DefaultHalfTolerance)) ?? T.CreateTruncating(0);
}

public static bool IsEqualWithTolerance<T>(T expected, T actual, T? tolerance = null) where T : unmanaged, INumber<T>
public static void AssertEqualWithTolerance<T>(T expected, T actual, T? tolerance = null, string? banner = null) where T : unmanaged, INumber<T>
{
if (T.IsNaN(expected) != T.IsNaN(actual))
T actualTolerance = tolerance ?? DefaultTolerance<T>.Value;
try
{
return false;
T scaledTolerance = checked(T.Max(T.Abs(expected), T.Abs(actual)) * actualTolerance);
if (T.IsFinite(scaledTolerance))
{
actualTolerance = T.Max(scaledTolerance, actualTolerance);
}
}
Comment thread
lilinus marked this conversation as resolved.
catch (OverflowException) { } // Multiplication and T.Abs can throw for integers, just keep the original tolerance in that case.

tolerance = tolerance ?? DefaultTolerance<T>.Value;
T diff = T.Abs(expected - actual);
return !(diff > tolerance && diff > T.Max(T.Abs(expected), T.Abs(actual)) * tolerance);
// Delegate to AssertExtensions.Equal for special value comparisons (NaN, +-inf, +-0)
if (typeof(T) == typeof(double))
{
AssertExtensions.Equal((double)(object)expected, (double)(object)actual, (double)(object)actualTolerance, banner);
}
else if (typeof(T) == typeof(float))
{
AssertExtensions.Equal((float)(object)expected, (float)(object)actual, (float)(object)actualTolerance, banner);
}
else if (typeof(T) == typeof(Half))
{
AssertExtensions.Equal((Half)(object)expected, (Half)(object)actual, (Half)(object)actualTolerance, banner);
}
else if (typeof(T) == typeof(NFloat))
{
AssertExtensions.Equal((NFloat)(object)expected, (NFloat)(object)actual, (NFloat)(object)actualTolerance, banner);
}
else if (typeof(T) == typeof(sbyte) || typeof(T) == typeof(byte) ||
typeof(T) == typeof(short) || typeof(T) == typeof(ushort) || typeof(T) == typeof(char) ||
typeof(T) == typeof(int) || typeof(T) == typeof(uint) ||
typeof(T) == typeof(long) || typeof(T) == typeof(ulong) ||
typeof(T) == typeof(nint) || typeof(T) == typeof(nuint) ||
typeof(T) == typeof(Int128) || typeof(T) == typeof(UInt128))
{
T delta;
try
{
delta = T.Abs(checked(expected - actual));
}
catch (OverflowException)
{
// Subtraction and T.Abs can throw for integers, in that case the mismatch is large enough to fail assertion
throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString(), banner);
}
if (delta > actualTolerance)
{
throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString(), banner);
}
}
Comment thread
lilinus marked this conversation as resolved.
else
{
throw new NotImplementedException($"Type not supported for {nameof(AssertEqualWithTolerance)}: {typeof(T).Name}");
}
}
#else
public static bool IsEqualWithTolerance(float expected, float actual, float? tolerance = null)
public static void AssertEqualWithTolerance(float expected, float actual, float? tolerance = null, string? banner = null)
{
if (float.IsNaN(expected) != float.IsNaN(actual))
float actualTolerance = tolerance ?? DefaultFloatTolerance;
float scaledTolerance = MathF.Max(MathF.Abs(expected), MathF.Abs(actual)) * (tolerance ?? DefaultFloatTolerance);
if (!float.IsNaN(scaledTolerance) && !float.IsInfinity(scaledTolerance))
{
return false;
actualTolerance = MathF.Max(actualTolerance, scaledTolerance);
}
Comment thread
lilinus marked this conversation as resolved.

tolerance ??= DefaultFloatTolerance;
float diff = MathF.Abs(expected - actual);
return !(diff > tolerance && diff > MathF.Max(MathF.Abs(expected), MathF.Abs(actual)) * tolerance);
AssertExtensions.Equal(expected, actual, actualTolerance, banner);
}
#endif

Expand All @@ -82,13 +129,13 @@ public static bool IsEqualWithTolerance(float expected, float actual, float? tol
}
else if (typeof(T) == typeof(NFloat))
{
if (IntPtr.Size == 8 && doubleTolerance != null)
if (NFloat.Size == 8 && doubleTolerance != null)
{
return (T?)(object)(NFloat)doubleTolerance;
}
else if (IntPtr.Size == 4 && floatTolerance != null)
else if (NFloat.Size == 4 && floatTolerance != null)
{
return (T?)(object)(NFloat)doubleTolerance;
return (T?)(object)(NFloat)floatTolerance;
}
}
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ private static void ConvertTruncatingImpl<TFrom, TTo>()
where TFrom : unmanaged, INumber<TFrom>
where TTo : unmanaged, INumber<TTo>
{
string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertTruncating<TFrom, TTo>(new TFrom[3], new TTo[2]));

Random rand = new(42);
Expand All @@ -116,10 +117,7 @@ private static void ConvertTruncatingImpl<TFrom, TTo>()

for (int i = 0; i < tensorLength; i++)
{
if (!Helpers.IsEqualWithTolerance(TTo.CreateTruncating(source.Span[i]), destination.Span[i]))
{
throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateTruncating(source.Span[i])}.");
}
Helpers.AssertEqualWithTolerance(TTo.CreateTruncating(source.Span[i]), destination.Span[i], banner: banner);
}
}
}
Expand All @@ -128,6 +126,7 @@ private static void ConvertSaturatingImpl<TFrom, TTo>()
where TFrom : unmanaged, INumber<TFrom>
where TTo : unmanaged, INumber<TTo>
{
string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertSaturating<TFrom, TTo>(new TFrom[3], new TTo[2]));

Random rand = new(42);
Expand All @@ -148,10 +147,7 @@ private static void ConvertSaturatingImpl<TFrom, TTo>()

for (int i = 0; i < tensorLength; i++)
{
if (!Helpers.IsEqualWithTolerance(TTo.CreateSaturating(source.Span[i]), destination.Span[i]))
{
throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateSaturating(source.Span[i])}.");
}
Helpers.AssertEqualWithTolerance(TTo.CreateSaturating(source.Span[i]), destination.Span[i], banner: banner);
}
}
}
Expand All @@ -160,6 +156,7 @@ private static void ConvertCheckedImpl<TFrom, TTo>()
where TFrom : unmanaged, INumber<TFrom>
where TTo : unmanaged, INumber<TTo>
{
string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertChecked<TFrom, TTo>(new TFrom[3], new TTo[2]));

foreach (int tensorLength in Helpers.TensorLengthsIncluding0)
Expand All @@ -180,10 +177,7 @@ private static void ConvertCheckedImpl<TFrom, TTo>()

for (int i = 0; i < tensorLength; i++)
{
if (!Helpers.IsEqualWithTolerance(TTo.CreateChecked(source.Span[i]), destination.Span[i]))
{
throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Actual: {destination.Span[i]}. Expected: {TTo.CreateChecked(source.Span[i])}.");
}
Helpers.AssertEqualWithTolerance(TTo.CreateChecked(source.Span[i]), destination.Span[i], banner: banner);
}
}
}
Expand All @@ -192,6 +186,8 @@ private static void ConvertCheckedImpl<TFrom, TTo>(TFrom valid, TFrom invalid)
where TFrom : unmanaged, INumber<TFrom>
where TTo : unmanaged, INumber<TTo>
{
string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";

foreach (int tensorLength in Helpers.TensorLengths)
{
using BoundedMemory<TFrom> source = BoundedMemory.Allocate<TFrom>(tensorLength);
Expand All @@ -202,7 +198,7 @@ private static void ConvertCheckedImpl<TFrom, TTo>(TFrom valid, TFrom invalid)
TensorPrimitives.ConvertChecked<TFrom, TTo>(source.Span, destination.Span);
foreach (TTo result in destination.Span)
{
Assert.True(Helpers.IsEqualWithTolerance(TTo.CreateChecked(valid), result));
Helpers.AssertEqualWithTolerance(TTo.CreateChecked(valid), result, banner: banner);
}
Comment thread
lilinus marked this conversation as resolved.

// Test with at least one invalid
Expand Down Expand Up @@ -258,6 +254,7 @@ private static void ConvertToIntegerImpl<TFrom, TTo>()
where TFrom : unmanaged, IFloatingPoint<TFrom>
where TTo : unmanaged, IBinaryInteger<TTo>
{
string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertToInteger<TFrom, TTo>(new TFrom[3], new TTo[2]));

Random rand = new(42);
Expand All @@ -277,10 +274,7 @@ private static void ConvertToIntegerImpl<TFrom, TTo>()
for (int i = 0; i < tensorLength; i++)
{
TTo expected = TFrom.ConvertToInteger<TTo>(source.Span[i]);
if (!Helpers.IsEqualWithTolerance(expected, destination.Span[i]))
{
throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Expected: {expected}. Actual: {destination.Span[i]}.");
}
Helpers.AssertEqualWithTolerance(expected, destination.Span[i], banner: banner);
}
}
}
Expand All @@ -289,6 +283,7 @@ private static void ConvertToIntegerNativeImpl<TFrom, TTo>()
where TFrom : unmanaged, IFloatingPoint<TFrom>
where TTo : unmanaged, IBinaryInteger<TTo>
{
string banner = $"{typeof(TFrom).Name} => {typeof(TTo).Name}";
AssertExtensions.Throws<ArgumentException>("destination", () => TensorPrimitives.ConvertToIntegerNative<TFrom, TTo>(new TFrom[3], new TTo[2]));

Random rand = new(42);
Expand All @@ -308,10 +303,7 @@ private static void ConvertToIntegerNativeImpl<TFrom, TTo>()
for (int i = 0; i < tensorLength; i++)
{
TTo expected = TFrom.ConvertToIntegerNative<TTo>(source.Span[i]);
if (!Helpers.IsEqualWithTolerance(expected, destination.Span[i]))
{
throw new XunitException($"{typeof(TFrom).Name} => {typeof(TTo).Name}. Input: {source.Span[i]}. Expected: {expected}. Actual: {destination.Span[i]}.");
}
Helpers.AssertEqualWithTolerance(expected, destination.Span[i], banner: banner);
}
}
}
Expand Down Expand Up @@ -2752,10 +2744,7 @@ protected override T NextRandom()

protected override void AssertEqualTolerance(T expected, T actual, T? tolerance = null)
{
if (!Helpers.IsEqualWithTolerance(expected, actual, tolerance))
{
throw EqualException.ForMismatchedValues($"{expected}", $"{actual}");
}
Helpers.AssertEqualWithTolerance(expected, actual, tolerance);
}

protected override T Cosh(T x) => throw new NotSupportedException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ protected override float MinMagnitude(float x, float y)

protected override void AssertEqualTolerance(float expected, float actual, float? tolerance = null)
{
if (!Helpers.IsEqualWithTolerance(expected, actual, tolerance))
{
throw EqualException.ForMismatchedValues(expected.ToString(), actual.ToString());
}
Helpers.AssertEqualWithTolerance(expected, actual, tolerance);
}

protected override IEnumerable<float> GetSpecialValues()
Expand Down
Loading