Skip to content

Commit bdd7b7a

Browse files
Vectorize TensorPrimitives.Exp (#93018)
* Vectorize TensorPrimitives.Exp * Update src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs
1 parent b4bb155 commit bdd7b7a

3 files changed

Lines changed: 295 additions & 14 deletions

File tree

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.cs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -322,20 +322,8 @@ public static float Dot(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
322322
/// operating systems or architectures.
323323
/// </para>
324324
/// </remarks>
325-
public static void Exp(ReadOnlySpan<float> x, Span<float> destination)
326-
{
327-
// The destination needs room for one result per input element.
if (x.Length > destination.Length)
328-
{
329-
ThrowHelper.ThrowArgument_DestinationTooShort();
330-
}
331-
332-
// NOTE(review): presumably rejects x/destination pairs that overlap; confirm against the helper.
ValidateInputOutputSpanNonOverlapping(x, destination);
333-
334-
// Scalar path: one MathF.Exp call per element, written to the same index in destination.
for (int i = 0; i < x.Length; i++)
335-
{
336-
destination[i] = MathF.Exp(x[i]);
337-
}
338-
}
325+
// NOTE(review): the length check and overlap validation that used to live inline here are
// presumably performed by InvokeSpanIntoSpan; confirm before relying on it.
public static void Exp(ReadOnlySpan<float> x, Span<float> destination) =>
326+
InvokeSpanIntoSpan<ExpOperator>(x, destination);
339327

340328
/// <summary>Searches for the index of the largest single-precision floating-point number in the specified tensor.</summary>
341329
/// <param name="x">The tensor, represented as a span.</param>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2579,6 +2579,286 @@ public static Vector512<float> Invoke(Vector512<float> x, Vector512<float> y)
25792579
#endif
25802580
}
25812581

2582+
private readonly struct ExpOperator : IUnaryOperator
2583+
{
2584+
// This code is based on `vrs4_expf` from amd/aocl-libm-ose
2585+
// Copyright (C) 2019-2022 Advanced Micro Devices, Inc. All rights reserved.
2586+
//
2587+
// Licensed under the BSD 3-Clause "New" or "Revised" License
2588+
// See THIRD-PARTY-NOTICES.TXT for the full license text
2589+
2590+
// Implementation Notes (e^x for single-precision inputs, evaluated in double precision):
2591+
// 1. Argument Reduction:
2592+
// e^x = 2^(x/ln2) --- (1)
2593+
//
2594+
// Let x/ln(2) = z --- (2)
2595+
//
2596+
// Let z = n + r , where n is an integer --- (3)
2597+
// |r| <= 1/2
2598+
//
2599+
// From (1), (2) and (3),
2600+
// e^x = 2^z
2601+
// = 2^(n+r)
2602+
// = (2^n)*(2^r) --- (4)
2603+
//
2604+
// 2. Polynomial Evaluation
2605+
// From (3) and (4),
2606+
// r = z - n
2607+
// 2^r ~= C1 + C2*r + C3*r^2 + C4*r^3 + C5*r^4 + C6*r^5
2608+
//
2609+
// 3. Reconstruction
2610+
// Thus,
2611+
// e^x = (2^n) * (2^r)
2612+
2613+
private const uint V_ARG_MAX = 0x42AE0000;
2614+
private const uint V_MASK = 0x7FFFFFFF;
2615+
2616+
private const float V_EXPF_MIN = -103.97208f;
2617+
private const float V_EXPF_MAX = 88.72284f;
2618+
2619+
private const double V_EXPF_HUGE = 6755399441055744;
2620+
private const double V_TBL_LN2 = 1.4426950408889634;
2621+
2622+
private const double C1 = 1.0000000754895704;
2623+
private const double C2 = 0.6931472254087585;
2624+
private const double C3 = 0.2402210737432219;
2625+
private const double C4 = 0.05550297297702539;
2626+
private const double C5 = 0.009676036358193323;
2627+
private const double C6 = 0.001341000536524434;
2628+
2629+
public static float Invoke(float x) => MathF.Exp(x);
2630+
2631+
public static Vector128<float> Invoke(Vector128<float> x)
2632+
{
2633+
// Convert x to double precision (lower and upper halves are processed separately)
2634+
(Vector128<double> xl, Vector128<double> xu) = Vector128.Widen(x);
2635+
2636+
// z = x * (1 / ln(2)); V_TBL_LN2 holds 1/ln(2), there is no factor of 64 here
2637+
Vector128<double> v_tbl_ln2 = Vector128.Create(V_TBL_LN2);
2638+
2639+
Vector128<double> zl = xl * v_tbl_ln2;
2640+
Vector128<double> zu = xu * v_tbl_ln2;
2641+
2642+
Vector128<double> v_expf_huge = Vector128.Create(V_EXPF_HUGE);
2643+
2644+
Vector128<double> dnl = zl + v_expf_huge;
2645+
Vector128<double> dnu = zu + v_expf_huge;
2646+
2647+
// n = int (z): adding V_EXPF_HUGE (1.5 * 2^52) rounds z to an integer held in the low bits of the bit pattern
2648+
Vector128<ulong> nl = dnl.AsUInt64();
2649+
Vector128<ulong> nu = dnu.AsUInt64();
2650+
2651+
// dn = double(n)
2652+
dnl -= v_expf_huge;
2653+
dnu -= v_expf_huge;
2654+
2655+
// r = z - dn, then evaluate the degree-5 polynomial approximation of 2^r
2656+
Vector128<double> c1 = Vector128.Create(C1);
2657+
Vector128<double> c2 = Vector128.Create(C2);
2658+
Vector128<double> c3 = Vector128.Create(C3);
2659+
Vector128<double> c4 = Vector128.Create(C4);
2660+
Vector128<double> c5 = Vector128.Create(C5);
2661+
Vector128<double> c6 = Vector128.Create(C6);
2662+
2663+
Vector128<double> rl = zl - dnl;
2664+
2665+
Vector128<double> rl2 = rl * rl;
2666+
Vector128<double> rl4 = rl2 * rl2;
2667+
2668+
Vector128<double> polyl = (c4 * rl + c3) * rl2
2669+
+ ((c6 * rl + c5) * rl4
2670+
+ (c2 * rl + c1));
2671+
2672+
2673+
Vector128<double> ru = zu - dnu;
2674+
2675+
Vector128<double> ru2 = ru * ru;
2676+
Vector128<double> ru4 = ru2 * ru2;
2677+
2678+
Vector128<double> polyu = (c4 * ru + c3) * ru2
2679+
+ ((c6 * ru + c5) * ru4
2680+
+ (c2 * ru + c1));
2681+
2682+
// result = (float)[poly + (n << 52)]; the shift adds n to poly's exponent field, i.e. poly * 2^n
2683+
Vector128<float> ret = Vector128.Narrow(
2684+
(polyl.AsUInt64() + Vector128.ShiftLeft(nl, 52)).AsDouble(),
2685+
(polyu.AsUInt64() + Vector128.ShiftLeft(nu, 52)).AsDouble()
2686+
);
2687+
2688+
// Fix-up path: taken when any lane's |x| bit pattern exceeds V_ARG_MAX (87.0f); NaN and infinity patterns also compare greater
2689+
if (Vector128.GreaterThanAny(x.AsUInt32() & Vector128.Create(V_MASK), Vector128.Create(V_ARG_MAX)))
2690+
{
2691+
// ret = (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
2692+
Vector128<float> infinityMask = Vector128.GreaterThan(x, Vector128.Create(V_EXPF_MAX));
2693+
2694+
ret = Vector128.ConditionalSelect(
2695+
infinityMask,
2696+
Vector128.Create(float.PositiveInfinity),
2697+
ret
2698+
);
2699+
2700+
// ret = (x < V_EXPF_MIN) ? 0 : ret
2701+
ret = Vector128.AndNot(ret, Vector128.LessThan(x, Vector128.Create(V_EXPF_MIN)));
2702+
}
2703+
2704+
return ret;
2705+
}
2706+
2707+
public static Vector256<float> Invoke(Vector256<float> x)
2708+
{
2709+
// Convert x to double precision (lower and upper halves are processed separately)
2710+
(Vector256<double> xl, Vector256<double> xu) = Vector256.Widen(x);
2711+
2712+
// z = x * (1 / ln(2)); V_TBL_LN2 holds 1/ln(2), there is no factor of 64 here
2713+
Vector256<double> v_tbl_ln2 = Vector256.Create(V_TBL_LN2);
2714+
2715+
Vector256<double> zl = xl * v_tbl_ln2;
2716+
Vector256<double> zu = xu * v_tbl_ln2;
2717+
2718+
Vector256<double> v_expf_huge = Vector256.Create(V_EXPF_HUGE);
2719+
2720+
Vector256<double> dnl = zl + v_expf_huge;
2721+
Vector256<double> dnu = zu + v_expf_huge;
2722+
2723+
// n = int (z): adding V_EXPF_HUGE (1.5 * 2^52) rounds z to an integer held in the low bits of the bit pattern
2724+
Vector256<ulong> nl = dnl.AsUInt64();
2725+
Vector256<ulong> nu = dnu.AsUInt64();
2726+
2727+
// dn = double(n)
2728+
dnl -= v_expf_huge;
2729+
dnu -= v_expf_huge;
2730+
2731+
// r = z - dn, then evaluate the degree-5 polynomial approximation of 2^r
2732+
Vector256<double> c1 = Vector256.Create(C1);
2733+
Vector256<double> c2 = Vector256.Create(C2);
2734+
Vector256<double> c3 = Vector256.Create(C3);
2735+
Vector256<double> c4 = Vector256.Create(C4);
2736+
Vector256<double> c5 = Vector256.Create(C5);
2737+
Vector256<double> c6 = Vector256.Create(C6);
2738+
2739+
Vector256<double> rl = zl - dnl;
2740+
2741+
Vector256<double> rl2 = rl * rl;
2742+
Vector256<double> rl4 = rl2 * rl2;
2743+
2744+
Vector256<double> polyl = (c4 * rl + c3) * rl2
2745+
+ ((c6 * rl + c5) * rl4
2746+
+ (c2 * rl + c1));
2747+
2748+
2749+
Vector256<double> ru = zu - dnu;
2750+
2751+
Vector256<double> ru2 = ru * ru;
2752+
Vector256<double> ru4 = ru2 * ru2;
2753+
2754+
Vector256<double> polyu = (c4 * ru + c3) * ru2
2755+
+ ((c6 * ru + c5) * ru4
2756+
+ (c2 * ru + c1));
2757+
2758+
// result = (float)[poly + (n << 52)]; the shift adds n to poly's exponent field, i.e. poly * 2^n
2759+
Vector256<float> ret = Vector256.Narrow(
2760+
(polyl.AsUInt64() + Vector256.ShiftLeft(nl, 52)).AsDouble(),
2761+
(polyu.AsUInt64() + Vector256.ShiftLeft(nu, 52)).AsDouble()
2762+
);
2763+
2764+
// Fix-up path: taken when any lane's |x| bit pattern exceeds V_ARG_MAX (87.0f); NaN and infinity patterns also compare greater
2765+
if (Vector256.GreaterThanAny(x.AsUInt32() & Vector256.Create(V_MASK), Vector256.Create(V_ARG_MAX)))
2766+
{
2767+
// ret = (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
2768+
Vector256<float> infinityMask = Vector256.GreaterThan(x, Vector256.Create(V_EXPF_MAX));
2769+
2770+
ret = Vector256.ConditionalSelect(
2771+
infinityMask,
2772+
Vector256.Create(float.PositiveInfinity),
2773+
ret
2774+
);
2775+
2776+
// ret = (x < V_EXPF_MIN) ? 0 : ret
2777+
ret = Vector256.AndNot(ret, Vector256.LessThan(x, Vector256.Create(V_EXPF_MIN)));
2778+
}
2779+
2780+
return ret;
2781+
}
2782+
2783+
#if NET8_0_OR_GREATER
2784+
public static Vector512<float> Invoke(Vector512<float> x)
2785+
{
2786+
// Convert x to double precision (lower and upper halves are processed separately)
2787+
(Vector512<double> xl, Vector512<double> xu) = Vector512.Widen(x);
2788+
2789+
// z = x * (1 / ln(2)); V_TBL_LN2 holds 1/ln(2), there is no factor of 64 here
2790+
Vector512<double> v_tbl_ln2 = Vector512.Create(V_TBL_LN2);
2791+
2792+
Vector512<double> zl = xl * v_tbl_ln2;
2793+
Vector512<double> zu = xu * v_tbl_ln2;
2794+
2795+
Vector512<double> v_expf_huge = Vector512.Create(V_EXPF_HUGE);
2796+
2797+
Vector512<double> dnl = zl + v_expf_huge;
2798+
Vector512<double> dnu = zu + v_expf_huge;
2799+
2800+
// n = int (z): adding V_EXPF_HUGE (1.5 * 2^52) rounds z to an integer held in the low bits of the bit pattern
2801+
Vector512<ulong> nl = dnl.AsUInt64();
2802+
Vector512<ulong> nu = dnu.AsUInt64();
2803+
2804+
// dn = double(n)
2805+
dnl -= v_expf_huge;
2806+
dnu -= v_expf_huge;
2807+
2808+
// r = z - dn, then evaluate the degree-5 polynomial approximation of 2^r
2809+
Vector512<double> c1 = Vector512.Create(C1);
2810+
Vector512<double> c2 = Vector512.Create(C2);
2811+
Vector512<double> c3 = Vector512.Create(C3);
2812+
Vector512<double> c4 = Vector512.Create(C4);
2813+
Vector512<double> c5 = Vector512.Create(C5);
2814+
Vector512<double> c6 = Vector512.Create(C6);
2815+
2816+
Vector512<double> rl = zl - dnl;
2817+
2818+
Vector512<double> rl2 = rl * rl;
2819+
Vector512<double> rl4 = rl2 * rl2;
2820+
2821+
Vector512<double> polyl = (c4 * rl + c3) * rl2
2822+
+ ((c6 * rl + c5) * rl4
2823+
+ (c2 * rl + c1));
2824+
2825+
2826+
Vector512<double> ru = zu - dnu;
2827+
2828+
Vector512<double> ru2 = ru * ru;
2829+
Vector512<double> ru4 = ru2 * ru2;
2830+
2831+
Vector512<double> polyu = (c4 * ru + c3) * ru2
2832+
+ ((c6 * ru + c5) * ru4
2833+
+ (c2 * ru + c1));
2834+
2835+
// result = (float)[poly + (n << 52)]; the shift adds n to poly's exponent field, i.e. poly * 2^n
2836+
Vector512<float> ret = Vector512.Narrow(
2837+
(polyl.AsUInt64() + Vector512.ShiftLeft(nl, 52)).AsDouble(),
2838+
(polyu.AsUInt64() + Vector512.ShiftLeft(nu, 52)).AsDouble()
2839+
);
2840+
2841+
// Fix-up path: taken when any lane's |x| bit pattern exceeds V_ARG_MAX (87.0f); NaN and infinity patterns also compare greater
2842+
if (Vector512.GreaterThanAny(x.AsUInt32() & Vector512.Create(V_MASK), Vector512.Create(V_ARG_MAX)))
2843+
{
2844+
// ret = (x > V_EXPF_MAX) ? float.PositiveInfinity : ret
2845+
Vector512<float> infinityMask = Vector512.GreaterThan(x, Vector512.Create(V_EXPF_MAX));
2846+
2847+
ret = Vector512.ConditionalSelect(
2848+
infinityMask,
2849+
Vector512.Create(float.PositiveInfinity),
2850+
ret
2851+
);
2852+
2853+
// ret = (x < V_EXPF_MIN) ? 0 : ret
2854+
ret = Vector512.AndNot(ret, Vector512.LessThan(x, Vector512.Create(V_EXPF_MIN)));
2855+
}
2856+
2857+
return ret;
2858+
}
2859+
#endif
2860+
}
2861+
25822862
private readonly struct LogOperator : IUnaryOperator
25832863
{
25842864
// This code is based on `vrs4_logf` from amd/aocl-libm-ose

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,19 @@ public Vector<float> Invoke(Vector<float> x, Vector<float> y)
923923
public Vector<float> Invoke(Vector<float> x) => Vector.Abs(x);
924924
}
925925

926+
private readonly struct ExpOperator : IUnaryOperator
927+
{
928+
// No vector path on this target; the scalar overload below defers to MathF.Exp.
public bool CanVectorize => false;
929+
930+
public float Invoke(float x) => MathF.Exp(x);
931+
932+
public Vector<float> Invoke(Vector<float> x)
933+
{
934+
// Vectorizing requires shift left support, which is .NET 7 or later.
// NOTE(review): with CanVectorize returning false this overload is presumably never reached; confirm against the caller.
935+
throw new NotImplementedException();
936+
}
937+
}
938+
926939
private readonly struct LogOperator : IUnaryOperator
927940
{
928941
public bool CanVectorize => false;

0 commit comments

Comments
 (0)