Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(result);
if (nanMask != Vector512<T>.Zero)
{
return result.GetElement(IndexOfFirstMatch(nanMask));
return result.GetElement(Vector512.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -277,7 +277,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector512.Equals(current, current);
if (nanMask != Vector512<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector512.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -296,7 +296,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector512.Equals(current, current);
if (nanMask != Vector512<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector512.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -323,7 +323,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector256.Equals(result, result);
if (nanMask != Vector256<T>.Zero)
{
return result.GetElement(IndexOfFirstMatch(nanMask));
return result.GetElement(Vector256.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -342,7 +342,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector256.Equals(current, current);
if (nanMask != Vector256<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector256.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -362,7 +362,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = ~Vector256.Equals(current, current);
if (nanMask != Vector256<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector256.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -389,7 +389,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(result);
if (nanMask != Vector128<T>.Zero)
{
return result.GetElement(IndexOfFirstMatch(nanMask));
return result.GetElement(Vector128.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -408,7 +408,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(current);
if (nanMask != Vector128<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector128.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand All @@ -427,7 +427,7 @@ private static T MinMaxCore<T, TMinMaxOperator>(ReadOnlySpan<T> x)
nanMask = IsNaN(current);
if (nanMask != Vector128<T>.Zero)
{
return current.GetElement(IndexOfFirstMatch(nanMask));
return current.GetElement(Vector128.IndexOfWhereAllBitsSet(nanMask));
}
}

Expand Down
55 changes: 25 additions & 30 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -530,16 +530,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as below
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
Comment on lines -533 to +534

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@EgorBo, so on x64 this is basically going to do:

                                        ; Approx 8 total cycles
    vxorps    xmm0, xmm0, xmm0          ; 0 cycles
    vpcmpeqb  xmm0, xmm0, xmm1          ; 1 cycle
    vptest    xmm0, xmm0                ; 7 cycles
    jz        SHORT NO_MATCH            ; fused

MATCH:                                  ; Approx 10 total cycles
    vpmovmskb eax, xmm0                 ; 5 cycles
    tzcnt     eax, eax                  ; 1 cycle
    mov       ecx, -1                   ; 1 cycle
    cmp       eax, 32                   ; 1 cycle
    cmove     eax, ecx                  ; 1 cycle
    add       eax, edx                  ; 1 cycle
    ret                                 ; return

NO_MATCH:
    ; ...

and on Arm64 (neoverse v2):

                                        ; Approx 7 total cycles
    cmeq    v16.16b, v0.16b, #0         ; 2 cycles
    umaxp   v17.4s, v16.4s, v16.4s      ; 2 cycles
    umov    x1, v17.d[0]                ; 2 cycles
    cmp     x1, #0                      ; 1 cycle
    b.eq    NO_MATCH                    ; branch

MATCH:                                  ; Approx 10 total cycles
    shrn    v16.8b, v16.8h, #4          ; 2 cycles
    umov    x1, v16.d[0]                ; 2 cycles
    rbit    x1, x1                      ; 1 cycle
    clz     x1, x1                      ; 1 cycle
    lsr     w1, w1, #2                  ; 1 cycles
    movn    w2, #0                      ; 1 cycle
    cmp     w1, #16                     ; 1 cycle
    csel    w1, w1, w2, ne              ; fused
    add     w0, w0, w1                  ; 1 cycle
    ret     lr                          ; return

NO_MATCH:
    ; ...

More ideally the JIT could recognize this general pattern and generate this instead for x64:

                                        ; Approx 7 total cycles
    vxorps    xmm0, xmm0, xmm0          ; 0 cycles
    vpcmpeqb  xmm0, xmm0, xmm1          ; 1 cycle
    vpmovmskb eax, xmm0                 ; 5 cycles
    cmp       eax, 0                    ; 1 cycle
    jz        SHORT NO_MATCH            ; fused

MATCH:                                  ; Approx 2 total cycle
    tzcnt     eax, eax                  ; 1 cycle
    add       eax, edx                  ; 1 cycle
    ret                                 ; return

NO_MATCH:
    ; ...

and this on Arm64:

                                        ; Approx 7 total cycles
    cmeq    v16.16b, v0.16b, #0         ; 2 cycles
    shrn    v16.8b, v16.8h, #4          ; 2 cycles
    umov    x1, v16.d[0]                ; 2 cycles
    cmp     w1, #0                      ; 1 cycle
    b.eq    NO_MATCH

MATCH:                                  ; Approx 4 total cycle
    rbit    x1, x1                      ; 1 cycle
    clz     x1, x1                      ; 1 cycle
    lsr     w1, w1, #2                  ; 1 cycles
    add     w0, w0, w1                  ; 1 cycle
    ret     lr                          ; returnmm

NO_MATCH:
    ; ...

This would make it significantly cheaper for both, but I think requires us to recognize the != Zero followed by an Count/IndexOf/LastIndexOf pattern. Specifically I think CSE would trivially handle this for Arm64, but on x64 we'd need to transform the != Zero in that case so CSE could kick in.

What are your thoughts on this?


The alternative is we setup the managed code to look like this:

int index = Vector128.IndexOf(search, 0);

if (index < 0)
{
    // Zero flags set so no matches
    offset += (nuint)Vector128<byte>.Count;
}
else
{
    // Find bitflag offset of first match and add to current offset
    return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}

Then we'd get this (roughly) on x64:

                                        ; Approx 11 total cycles
    vxorps    xmm0, xmm0, xmm0          ; 0 cycles
    vpcmpeqb  xmm0, xmm0, xmm1          ; 1 cycle
    vpmovmskb eax, xmm0                 ; 5 cycles
    tzcnt     eax, eax                  ; 1 cycle
    mov       ecx, -1                   ; 1 cycle
    cmp       eax, 32                   ; 1 cycle
    cmove     eax, ecx                  ; 1 cycle
    cmp       eax, 0                    ; 1 cycle
    jl        SHORT NO_MATCH            ; fused

MATCH:                                  ; Approx 1 total cycle
    add       eax, edx                  ; 1 cycle
    ret                                 ; return

NO_MATCH:
    ; ...

and this on Arm64:

                                        ; Approx 10 total cycles
    cmeq    v16.16b, v0.16b, #0         ; 2 cycles
    shrn    v16.8b, v16.8h, #4          ; 2 cycles
    umov    x1, v16.d[0]                ; 2 cycles
    rbit    x1, x1                      ; 1 cycle
    clz     x1, x1                      ; 1 cycle
    lsr     w1, w1, #2                  ; 1 cycles
    cmp     w1, #0                      ; 1 cycle
    b.ge    NO_MATCH

MATCH:                                  ; Approx 1 total cycle
    add     w0, w0, w1                  ; 1 cycle
    ret     lr                          ; returnmm

NO_MATCH:
    ; ...

This is a little less than half the cost on match on both platforms, but has slightly higher cost for the no match scenario.

But I expect this is also difficult to pattern match and handle to get it to generate what we want in the first scenario, right?

We should probably pick one and have that be the "recommended pattern" where we then have the JIT handle it for the ideal codegen. -- The "other" other thing we could do is use Vector128.AnyWhereAllBitsSet(mask) instead of mask != Vector128<T>.Zero, which might then be easier to optimize overall, but interested in your thoughts so we can work towards getting it optimized and have managed follow our desired shape.

{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -552,16 +552,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector256<byte> search = Vector256.Load(searchSpace + offset);

// Same method as below
uint matches = Vector256.Equals(Vector256<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector256<byte> cmp = Vector256.Equals(Vector256<byte>.Zero, search);
if (cmp == Vector256<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector256<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector256.IndexOfFirstMatch(cmp));
}
}
lengthToExamine = GetByteVector512SpanLength(offset, Length);
Expand All @@ -570,18 +570,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
do
{
Vector512<byte> search = Vector512.Load(searchSpace + offset);
ulong matches = Vector512.Equals(Vector512<byte>.Zero, search).ExtractMostSignificantBits();
// Note that MoveMask has converted the equal vector elements into a set of bit flags,
// So the bit position in 'matches' corresponds to the element offset.
if (matches == 0)
Vector512<byte> cmp = Vector512.Equals(Vector512<byte>.Zero, search);
if (cmp == Vector512<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector512<byte>.Count;
continue;
}

// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector512.IndexOfFirstMatch(cmp));
} while (lengthToExamine > offset);
}

Expand All @@ -591,16 +589,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector256<byte> search = Vector256.Load(searchSpace + offset);

// Same method as above
uint matches = Vector256.Equals(Vector256<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector256<byte> cmp = Vector256.Equals(Vector256<byte>.Zero, search);
if (cmp == Vector256<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector256<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector256.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -610,16 +608,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as above
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -643,16 +641,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as below
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand All @@ -662,18 +660,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
do
{
Vector256<byte> search = Vector256.Load(searchSpace + offset);
uint matches = Vector256.Equals(Vector256<byte>.Zero, search).ExtractMostSignificantBits();
// Note that MoveMask has converted the equal vector elements into a set of bit flags,
// So the bit position in 'matches' corresponds to the element offset.
if (matches == 0)
Vector256<byte> cmp = Vector256.Equals(Vector256<byte>.Zero, search);
if (cmp == Vector256<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector256<byte>.Count;
continue;
}

// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector256.IndexOfFirstMatch(cmp));
} while (lengthToExamine > offset);
}

Expand All @@ -683,16 +679,16 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
Vector128<byte> search = Vector128.Load(searchSpace + offset);

// Same method as above
uint matches = Vector128.Equals(Vector128<byte>.Zero, search).ExtractMostSignificantBits();
if (matches == 0)
Vector128<byte> cmp = Vector128.Equals(Vector128<byte>.Zero, search);
if (cmp == Vector128<byte>.Zero)
{
// Zero flags set so no matches
offset += (nuint)Vector128<byte>.Count;
}
else
{
// Find bitflag offset of first match and add to current offset
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(cmp));
}
}

Expand Down Expand Up @@ -723,8 +719,7 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
}

// Find bitflag offset of first match and add to current offset
uint matches = compareResult.ExtractMostSignificantBits();
return (int)(offset + (uint)BitOperations.TrailingZeroCount(matches));
return (int)(offset + (uint)Vector128.IndexOfFirstMatch(compareResult));
}

if (offset < (nuint)(uint)Length)
Expand Down
30 changes: 9 additions & 21 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3713,49 +3713,37 @@ private static int LastIndexOfAnyValueType<TValue, TNegator>(ref TValue searchSp
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int ComputeFirstIndex<T>(ref T searchSpace, ref T current, Vector128<T> equals) where T : struct
{
uint notEqualsElements = equals.ExtractMostSignificantBits();
int index = BitOperations.TrailingZeroCount(notEqualsElements);
return index + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T));
return Vector128.IndexOfFirstMatch(equals) + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int ComputeFirstIndex<T>(ref T searchSpace, ref T current, Vector256<T> equals) where T : struct
{
uint notEqualsElements = equals.ExtractMostSignificantBits();
int index = BitOperations.TrailingZeroCount(notEqualsElements);
return index + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T));
return Vector256.IndexOfFirstMatch(equals) + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int ComputeFirstIndex<T>(ref T searchSpace, ref T current, Vector512<T> equals) where T : struct
{
ulong notEqualsElements = equals.ExtractMostSignificantBits();
int index = BitOperations.TrailingZeroCount(notEqualsElements);
return index + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T));
return Vector512.IndexOfFirstMatch(equals) + (int)((nuint)Unsafe.ByteOffset(ref searchSpace, ref current) / (nuint)sizeof(T));
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int ComputeLastIndex<T>(nint offset, Vector128<T> equals) where T : struct
{
uint notEqualsElements = equals.ExtractMostSignificantBits();
int index = 31 - BitOperations.LeadingZeroCount(notEqualsElements); // 31 = 32 (bits in Int32) - 1 (indexing from zero)
return (int)offset + index;
return (int)offset + Vector128.IndexOfLastMatch(equals);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int ComputeLastIndex<T>(nint offset, Vector256<T> equals) where T : struct
{
uint notEqualsElements = equals.ExtractMostSignificantBits();
int index = 31 - BitOperations.LeadingZeroCount(notEqualsElements); // 31 = 32 (bits in Int32) - 1 (indexing from zero)
return (int)offset + index;
return (int)offset + Vector256.IndexOfLastMatch(equals);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int ComputeLastIndex<T>(nint offset, Vector512<T> equals) where T : struct
{
ulong notEqualsElements = equals.ExtractMostSignificantBits();
int index = 63 - BitOperations.LeadingZeroCount(notEqualsElements); // 31 = 32 (bits in Int32) - 1 (indexing from zero)
return (int)offset + index;
return (int)offset + Vector512.IndexOfLastMatch(equals);
}

internal interface INegator<T> where T : struct
Expand Down Expand Up @@ -4176,7 +4164,7 @@ public static unsafe int CountValueType<T>(ref T current, T value, int length) w
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector512<T>.Count);
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd))
{
count += BitOperations.PopCount(Vector512.Equals(Vector512.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
count += Vector512.CountMatches(Vector512.Equals(Vector512.LoadUnsafe(ref current), targetVector));
current = ref Unsafe.Add(ref current, Vector512<T>.Count);
}

Expand All @@ -4191,7 +4179,7 @@ public static unsafe int CountValueType<T>(ref T current, T value, int length) w
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector256<T>.Count);
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd))
{
count += BitOperations.PopCount(Vector256.Equals(Vector256.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
count += Vector256.CountMatches(Vector256.Equals(Vector256.LoadUnsafe(ref current), targetVector));
current = ref Unsafe.Add(ref current, Vector256<T>.Count);
}

Expand All @@ -4206,7 +4194,7 @@ public static unsafe int CountValueType<T>(ref T current, T value, int length) w
ref T oneVectorAwayFromEnd = ref Unsafe.Subtract(ref end, Vector128<T>.Count);
while (Unsafe.IsAddressLessThan(ref current, ref oneVectorAwayFromEnd))
{
count += BitOperations.PopCount(Vector128.Equals(Vector128.LoadUnsafe(ref current), targetVector).ExtractMostSignificantBits());
count += Vector128.CountMatches(Vector128.Equals(Vector128.LoadUnsafe(ref current), targetVector));
current = ref Unsafe.Add(ref current, Vector128<T>.Count);
}

Expand Down
Loading