2

Given the following input bytes:

var vBytes = new Vector<byte>(new byte[] {72, 101, 55, 108, 108, 111, 55, 87, 111, 114, 108, 55, 100, 55, 55, 20});

And the given mask:

var mask = new Vector<byte>(55);

How can I find the count of byte 55 in the input array?

I have tried xoring the vBytes with the mask:

var xored = Vector.Xor(mask, vBytes);

which gives:

<127, 82, 0, 91, 91, 88, 0, 96, 88, 69, 91, 0, 83, 0, 0, 35>

But I don't know how to get the count from that.

For the sake of simplicity, let's assume that the input length is always equal to Vector<byte>.Count.

8
  • You mean without simple for loop? Commented Mar 29, 2018 at 10:03
  • FYI - Vector.Equals(vBytes,mask) is probably more intuitive than xor - it returns a vector of 255s/0s. How to count them, though... Commented Mar 29, 2018 at 10:04
  • @MarcGravell Awesome! I got it!, will update with the answer. Commented Mar 29, 2018 at 10:07
  • Vector.Dot(Vector.Negate(Vector.Equals(vBytes, new Vector<byte>(55))), new Vector<byte>(1)) would do it. However, I have no experience with SIMD and I don't know if this is a reasonable approach. Commented Mar 29, 2018 at 10:07
  • @MarcGravell: yup, packed byte compare, then use psadbw to horizontal-sum those results into 64-bit elements. Commented Mar 30, 2018 at 4:59

4 Answers

4

(AVX2 C intrinsics implementation of the below idea, in case a concrete example helps: How to count character occurrences using SIMD)

In asm, you want pcmpeqb to produce a vector of 0 or 0xFF. Treated as signed integers, that's 0/-1.

Then use the compare result as integer values with psubb to add 0 / 1 to the counter for that element. (Subtracting -1 adds +1.)

That can overflow after 256 iterations, so some time before that, use psadbw against _mm_setzero_si128() to horizontally sum those unsigned bytes (without overflow) into 64-bit integers (one 64-bit integer per group of 8 bytes). Then use paddq to accumulate 64-bit totals.

Accumulating before you overflow can be done with a nested loop, or just at the end of a regular unrolled loop. psadbw is fast (because it's a key building block for video encoding motion-search), so it's not bad to just accumulate every 4 compares, or even every 1 and skip the psubb.
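A minimal C sketch of the nested-loop variant described above, assuming an x86-64 target with SSE2. The name count_byte_sse2 and its signature are hypothetical and are not taken from the linked answer; the inner loop is capped at 255 vectors so the byte counters cannot wrap before psadbw flushes them.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: counts bytes equal to `value` in buf[0..n).
   Assumes x86-64; n does not have to be a multiple of 16. */
uint64_t count_byte_sse2(const uint8_t *buf, size_t n, uint8_t value) {
    const __m128i needle = _mm_set1_epi8((char)value);
    const __m128i zero = _mm_setzero_si128();
    __m128i total = zero;               /* two 64-bit running totals */
    size_t i = 0;

    while (n - i >= 16) {
        /* At most 255 vectors per inner loop, so no byte counter can wrap. */
        size_t vecs = (n - i) / 16;
        if (vecs > 255) vecs = 255;

        __m128i acc = zero;             /* sixteen 8-bit counters */
        for (size_t k = 0; k < vecs; ++k, i += 16) {
            __m128i v = _mm_loadu_si128((const __m128i *)(buf + i));
            /* pcmpeqb gives 0 or -1 per byte; psubb of -1 adds 1. */
            acc = _mm_sub_epi8(acc, _mm_cmpeq_epi8(v, needle));
        }
        /* psadbw against zero sums each group of 8 counters into a 64-bit lane. */
        total = _mm_add_epi64(total, _mm_sad_epu8(acc, zero));
    }

    /* Fold the high 64-bit lane onto the low one and move it to a scalar. */
    total = _mm_add_epi64(total, _mm_unpackhi_epi64(total, total));
    uint64_t count = (uint64_t)_mm_cvtsi128_si64(total);

    for (; i < n; ++i)                  /* scalar tail, fewer than 16 bytes */
        count += (buf[i] == value);
    return count;
}
```

Unrolling with several independent accumulators, as in the linked AVX2 answer, would follow the same pattern; this sketch keeps a single accumulator for readability.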

See Agner Fog's optimization guides for more details on x86. According to his instruction tables, psadbw xmm / vpsadbw ymm runs at 1 vector per clock cycle on Skylake, with 3 cycle latency. (Only 1 uop of front-end bandwidth.) All the instructions mentioned above are also single-uop, and run on more than one port (so don't necessarily conflict with each other for throughput). Their 128-bit versions only require SSE2.


If you really only have one vector at a time to count, and aren't looping over memory, then probably pcmpeqb / psadbw / pshufd (copy high half to low) / paddd / movd eax, xmm0 gives you 255 * number of matches in an integer register. One extra vector instruction (like subtract from zero, or AND with 1, or pabsb (absolute value)) would remove the x255 scale factor.
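The single-vector sequence above can be sketched with SSE2 intrinsics as follows. This is only an illustration: count_in_vector is a hypothetical name, and the subtract-from-zero step is the one chosen here to remove the x255 factor.

```c
#include <emmintrin.h>  /* SSE2 intrinsics; also provides _MM_SHUFFLE */
#include <stdint.h>

/* Hypothetical helper: number of bytes in one 16-byte block equal to `value`. */
int count_in_vector(const uint8_t block[16], uint8_t value) {
    __m128i v   = _mm_loadu_si128((const __m128i *)block);
    /* pcmpeqb: 0x00 or 0xFF per byte */
    __m128i eq  = _mm_cmpeq_epi8(v, _mm_set1_epi8((char)value));
    /* 0 - (-1) = 1 per matching byte: removes the x255 scale factor */
    __m128i one = _mm_sub_epi8(_mm_setzero_si128(), eq);
    /* psadbw: one partial sum in each 64-bit half */
    __m128i sad = _mm_sad_epu8(one, _mm_setzero_si128());
    /* pshufd: swap the two 64-bit halves so the high sum lands in the low half */
    __m128i hi  = _mm_shuffle_epi32(sad, _MM_SHUFFLE(1, 0, 3, 2));
    /* paddd + movd: add the halves and move the low 32 bits to a scalar */
    return _mm_cvtsi128_si32(_mm_add_epi32(sad, hi));
}
```

Called on the sixteen bytes from the question with value 55, this returns 5.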


I don't know how to write that in C# SIMD, but you definitely do not want a dot-product! Unpacking and converting to FP would be about 4x slower than the above, just from the fact that a fixed-width vector holds 4x more bytes than floats, and dpps (_mm_dp_ps) is not fast: 4 uops, and one per 1.5 cycle throughput on Skylake. If you do have to horizontal-sum something other than unsigned bytes, see Fastest way to do horizontal SSE vector sum (or other reduction) (my answer also includes integer).

Or if Vector.Dot uses pmaddubsw / pmaddwd for integer vectors, then that might not be as bad, but doing a multi-step horizontal sum for each vector of compare results is just bad compared to psadbw, or especially to byte accumulators that you only horizontal sum occasionally.

The same goes if C# optimizes away the actual multiplication by a constant vector of 1. Anyway, the first part of this answer is the code you want the CPU to be running. Make that happen however you like using whatever source code gets it to happen.


4

I know that I'm super late to the party, but so far none of the answers here actually provide a full solution. Here's my best attempt at one, derived from this Gist and the DotNet source code. All credit goes to the DotNet team and community members here (especially @Peter Cordes).

Usage:

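// The extensions are declared on Span<T>/ReadOnlySpan<T>, so convert arrays and strings with AsSpan() first.
var bytes = Encoding.ASCII.GetBytes("The quick brown fox jumps over the lazy dog.");
var byteCount = bytes.AsSpan().OccurrencesOf(32);

var chars = "The quick brown fox jumps over the lazy dog.";
var charCount = chars.AsSpan().OccurrencesOf(' ');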

Code:

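using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

public static class VectorExtensions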
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nuint GetByteVector128SpanLength(nuint offset, int length) =>
        ((nuint)(uint)((length - (int)offset) & ~(Vector128<byte>.Count - 1)));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nuint GetByteVector256SpanLength(nuint offset, int length) =>
        ((nuint)(uint)((length - (int)offset) & ~(Vector256<byte>.Count - 1)));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nint GetCharVector128SpanLength(nint offset, nint length) =>
        ((length - offset) & ~(Vector128<ushort>.Count - 1));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static nint GetCharVector256SpanLength(nint offset, nint length) =>
        ((length - offset) & ~(Vector256<ushort>.Count - 1));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector128<byte> LoadVector128(ref byte start, nuint offset) =>
        Unsafe.ReadUnaligned<Vector128<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector256<byte> LoadVector256(ref byte start, nuint offset) =>
        Unsafe.ReadUnaligned<Vector256<byte>>(ref Unsafe.AddByteOffset(ref start, offset));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector128<ushort> LoadVector128(ref char start, nint offset) =>
        Unsafe.ReadUnaligned<Vector128<ushort>>(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref start, offset)));
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static Vector256<ushort> LoadVector256(ref char start, nint offset) =>
        Unsafe.ReadUnaligned<Vector256<ushort>>(ref Unsafe.As<char, byte>(ref Unsafe.Add(ref start, offset)));
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    private static unsafe int OccurrencesOf(ref byte searchSpace, byte value, int length) {
        var lengthToExamine = ((nuint)length);
        var offset = ((nuint)0);
        var result = 0L;

        if (Sse2.IsSupported || Avx2.IsSupported) {
            if (31 < length) {
                lengthToExamine = UnalignedCountVector128(ref searchSpace);
            }
        }

    SequentialScan:
        while (7 < lengthToExamine) {
            ref byte current = ref Unsafe.AddByteOffset(ref searchSpace, offset);

            if (value == current) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 1)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 2)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 3)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 4)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 5)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 6)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 7)) {
                ++result;
            }

            lengthToExamine -= 8;
            offset += 8;
        }

        while (3 < lengthToExamine) {
            ref byte current = ref Unsafe.AddByteOffset(ref searchSpace, offset);

            if (value == current) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 1)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 2)) {
                ++result;
            }
            if (value == Unsafe.AddByteOffset(ref current, 3)) {
                ++result;
            }

            lengthToExamine -= 4;
            offset += 4;
        }

        while (0 < lengthToExamine) {
            if (value == Unsafe.AddByteOffset(ref searchSpace, offset)) {
                ++result;
            }

            --lengthToExamine;
            ++offset;
        }

        if (offset < ((nuint)(uint)length)) {
            if (Avx2.IsSupported) {
                if (0 != (((nuint)(uint)Unsafe.AsPointer(ref searchSpace) + offset) & (nuint)(Vector256<byte>.Count - 1))) {
                    var sum = Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<byte>.Zero, Sse2.CompareEqual(Vector128.Create(value), LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64();

                    offset += 16;
                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                lengthToExamine = GetByteVector256SpanLength(offset, length);

                var searchMask = Vector256.Create(value);

                if (127 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        var accumulator0 = Vector256<byte>.Zero;
                        var accumulator1 = Vector256<byte>.Zero;
                        var accumulator2 = Vector256<byte>.Zero;
                        var accumulator3 = Vector256<byte>.Zero;
                        var loopIndex = ((nuint)0);
                        var loopLimit = Math.Min(255, (lengthToExamine / 128));

                        do {
                            accumulator0 = Avx2.Subtract(accumulator0, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset)));
                            accumulator1 = Avx2.Subtract(accumulator1, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 32))));
                            accumulator2 = Avx2.Subtract(accumulator2, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 64))));
                            accumulator3 = Avx2.Subtract(accumulator3, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 96))));
                            loopIndex++;
                            offset += 128;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (128 * loopLimit);
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector256<byte>.Zero).AsInt64());
                    } while (127 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (31 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(Avx2.Subtract(Vector256<byte>.Zero, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset))).AsByte(), Vector256<byte>.Zero).AsInt64());
                        lengthToExamine -= 32;
                        offset += 32;
                    } while (31 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (offset < ((nuint)(uint)length)) {
                    lengthToExamine = (((nuint)(uint)length) - offset);

                    goto SequentialScan;
                }
            }
            else if (Sse2.IsSupported) {
                lengthToExamine = GetByteVector128SpanLength(offset, length);

                var searchMask = Vector128.Create(value);

                if (63 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        var accumulator0 = Vector128<byte>.Zero;
                        var accumulator1 = Vector128<byte>.Zero;
                        var accumulator2 = Vector128<byte>.Zero;
                        var accumulator3 = Vector128<byte>.Zero;
                        var loopIndex = ((nuint)0);
                        var loopLimit = Math.Min(255, (lengthToExamine / 64));

                        do {
                            accumulator0 = Sse2.Subtract(accumulator0, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset)));
                            accumulator1 = Sse2.Subtract(accumulator1, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 16))));
                            accumulator2 = Sse2.Subtract(accumulator2, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 32))));
                            accumulator3 = Sse2.Subtract(accumulator3, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 48))));
                            loopIndex++;
                            offset += 64;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (64 * loopLimit);
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector128<byte>.Zero).AsInt64());
                    } while (63 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (15 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<byte>.Zero, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64());
                        lengthToExamine -= 16;
                        offset += 16;
                    } while (15 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (offset < ((nuint)(uint)length)) {
                    lengthToExamine = (((nuint)(uint)length) - offset);

                    goto SequentialScan;
                }
            }
        }

        return ((int)result);
    }
    [MethodImpl(MethodImplOptions.AggressiveOptimization)]
    private static unsafe int OccurrencesOf(ref char searchSpace, char value, int length) {
        var lengthToExamine = ((nint)length);
        var offset = ((nint)0);
        var result = 0L;

        if (0 != ((int)Unsafe.AsPointer(ref searchSpace) & 1)) { }
        else if (Sse2.IsSupported || Avx2.IsSupported) {
            if (15 < length) {
                lengthToExamine = UnalignedCountVector128(ref searchSpace);
            }
        }

    SequentialScan:
        while (3 < lengthToExamine) {
            ref char current = ref Unsafe.Add(ref searchSpace, offset);

            if (value == current) {
                ++result;
            }
            if (value == Unsafe.Add(ref current, 1)) {
                ++result;
            }
            if (value == Unsafe.Add(ref current, 2)) {
                ++result;
            }
            if (value == Unsafe.Add(ref current, 3)) {
                ++result;
            }

            lengthToExamine -= 4;
            offset += 4;
        }

        while (0 < lengthToExamine) {
            if (value == Unsafe.Add(ref searchSpace, offset)) {
                ++result;
            }

            --lengthToExamine;
            ++offset;
        }

        if (offset < length) {
            if (Avx2.IsSupported) {
                if (0 != (((nint)Unsafe.AsPointer(ref Unsafe.Add(ref searchSpace, offset))) & (Vector256<byte>.Count - 1))) {
                    var sum = Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<ushort>.Zero, Sse2.CompareEqual(Vector128.Create(value), LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64();

                    offset += 8;
                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                lengthToExamine = GetCharVector256SpanLength(offset, length);

                var searchMask = Vector256.Create(value);

                if (63 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        var accumulator0 = Vector256<ushort>.Zero;
                        var accumulator1 = Vector256<ushort>.Zero;
                        var accumulator2 = Vector256<ushort>.Zero;
                        var accumulator3 = Vector256<ushort>.Zero;
                        var loopIndex = 0;
                        var loopLimit = Math.Min(255, (lengthToExamine / 64));

                        do {
                            accumulator0 = Avx2.Subtract(accumulator0, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset)));
                            accumulator1 = Avx2.Subtract(accumulator1, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 16))));
                            accumulator2 = Avx2.Subtract(accumulator2, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 32))));
                            accumulator3 = Avx2.Subtract(accumulator3, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, (offset + 48))));
                            loopIndex++;
                            offset += 64;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (64 * loopLimit);
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector256<byte>.Zero).AsInt64());
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector256<byte>.Zero).AsInt64());
                    } while (63 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (15 < lengthToExamine) {
                    var sum = Vector256<long>.Zero;

                    do {
                        sum = Avx2.Add(sum, Avx2.SumAbsoluteDifferences(Avx2.Subtract(Vector256<ushort>.Zero, Avx2.CompareEqual(searchMask, LoadVector256(ref searchSpace, offset))).AsByte(), Vector256<byte>.Zero).AsInt64());
                        lengthToExamine -= 16;
                        offset += 16;
                    } while (15 < lengthToExamine);

                    var sumX = Avx2.ExtractVector128(sum, 0);
                    var sumY = Avx2.ExtractVector128(sum, 1);
                    var sumZ = Sse2.Add(sumX, sumY);

                    result += (sumZ.GetElement(0) + sumZ.GetElement(1));
                }

                if (offset < length) {
                    lengthToExamine = (length - offset);

                    goto SequentialScan;
                }
            }
            else if (Sse2.IsSupported) {
                lengthToExamine = GetCharVector128SpanLength(offset, length);

                var searchMask = Vector128.Create(value);

                if (31 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        var accumulator0 = Vector128<ushort>.Zero;
                        var accumulator1 = Vector128<ushort>.Zero;
                        var accumulator2 = Vector128<ushort>.Zero;
                        var accumulator3 = Vector128<ushort>.Zero;
                        var loopIndex = 0;
                        var loopLimit = Math.Min(255, (lengthToExamine / 32));

                        do {
                            accumulator0 = Sse2.Subtract(accumulator0, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset)));
                            accumulator1 = Sse2.Subtract(accumulator1, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 8))));
                            accumulator2 = Sse2.Subtract(accumulator2, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 16))));
                            accumulator3 = Sse2.Subtract(accumulator3, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, (offset + 24))));
                            loopIndex++;
                            offset += 32;
                        } while (loopIndex < loopLimit);

                        lengthToExamine -= (32 * loopLimit);
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator0.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator1.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator2.AsByte(), Vector128<byte>.Zero).AsInt64());
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(accumulator3.AsByte(), Vector128<byte>.Zero).AsInt64());
                    } while (31 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (7 < lengthToExamine) {
                    var sum = Vector128<long>.Zero;

                    do {
                        sum = Sse2.Add(sum, Sse2.SumAbsoluteDifferences(Sse2.Subtract(Vector128<ushort>.Zero, Sse2.CompareEqual(searchMask, LoadVector128(ref searchSpace, offset))).AsByte(), Vector128<byte>.Zero).AsInt64());
                        lengthToExamine -= 8;
                        offset += 8;
                    } while (7 < lengthToExamine);

                    result += (sum.GetElement(0) + sum.GetElement(1));
                }

                if (offset < length) {
                    lengthToExamine = (length - offset);

                    goto SequentialScan;
                }
            }
        }

        return ((int)result);
    }
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe nuint UnalignedCountVector128(ref byte searchSpace) {
        nint unaligned = ((nint)Unsafe.AsPointer(ref searchSpace) & (Vector128<byte>.Count - 1));

        return ((nuint)(uint)((Vector128<byte>.Count - unaligned) & (Vector128<byte>.Count - 1)));
    }
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    private static unsafe nint UnalignedCountVector128(ref char searchSpace) {
        const int ElementsPerByte = (sizeof(ushort) / sizeof(byte));

        return ((nint)(uint)(-(int)Unsafe.AsPointer(ref searchSpace) / ElementsPerByte) & (Vector128<ushort>.Count - 1));
    }

    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this ReadOnlySpan<byte> span, byte value) =>
        OccurrencesOf(
            length: span.Length,
            searchSpace: ref MemoryMarshal.GetReference(span),
            value: value
        );
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this Span<byte> span, byte value) =>
        ((ReadOnlySpan<byte>)span).OccurrencesOf(value);
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this ReadOnlySpan<char> span, char value) =>
        OccurrencesOf(
            length: span.Length,
            searchSpace: ref MemoryMarshal.GetReference(span),
            value: value
        );
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static int OccurrencesOf(this Span<char> span, char value) =>
        ((ReadOnlySpan<char>)span).OccurrencesOf(value);
}

21 Comments

You can avoid the Avx2.Subtract(Vector256<ushort>.Zero, accumulator0) cleanup before PSADBW by using sub instead of add in the first place. i.e. accumulator -= cmp() because cmp results are -1 or 0. Also, do left + right and then reduce that, instead of separately extracting all 4 elements.
Cheers. Thanks for writing up an actual C# implementation; I don't really know C# beyond seeing it on SO, so I wasn't going to attempt this. How to count character occurrences using SIMD uses nested loops to deal with overflow, might want to take a look at exactly how it's implemented.
Also, in the outer loop, you don't need to reduce to scalar, just to one SIMD vector var sum would be fine. The hsum of that to a scalar integer can sink out of the outer loop. (Make sure you're using SIMD addition with 64-bit or at least 32-bit element size for all the accumulation of psadbw results. I guess SumAbsoluteDifferences() returns a vector<uint64_t> or whatever C# calls it, which implies the element type)
Also, consider what happens for an input that's say 31 bytes long (or 63x uint16 since you're doing strings not bytes like the question asked). Or n*64 + 31. That's a lot of scalar iterations. That's the downside to unrolling: unless you also provide a not-unrolled vector loop, you make the worst case (including small cases) spend more time in the slow scalar code. If you want to tune this for short to medium strings, you might provide a loop that does one SSE2 vector per iteration, leaving at most 7 leftover elements.
Oops, meant to write min(..., 255) to clamp to 255 as an upper limit on how many inner iterations to do, but do less if you're close to the end of the buffer. I frequently mix up min vs. max for setting a maximum on some value if I don't stop and think about it. :/
2

Here is a fast SSE2 implementation in C:

#include <emmintrin.h>  // SSE2
#include <smmintrin.h>  // SSE4.1, only needed for _mm_extract_epi64
#include <stddef.h>

size_t memcount_sse2(const void *s, int c, size_t n) {
    __m128i cv = _mm_set1_epi8((char)c), sum = _mm_setzero_si128(), acr0, acr1, acr2, acr3;
    const char *p, *pe;
    // Blocks of 252*16 bytes: 63 inner iterations, so the byte accumulators cannot overflow.
    for (p = s; p != (const char *)s + (n - (n % (252*16)));) {
        for (acr0 = acr1 = acr2 = acr3 = _mm_setzero_si128(), pe = p + 252*16; p != pe; p += 64) {
            // Each compare adds 0 or -1 to the matching byte lane.
            acr0 = _mm_add_epi8(acr0, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)p)));
            acr1 = _mm_add_epi8(acr1, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p + 16))));
            acr2 = _mm_add_epi8(acr2, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p + 32))));
            acr3 = _mm_add_epi8(acr3, _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)(p + 48))));
            __builtin_prefetch(p + 1024);
        }
        // Negate the accumulators, then psadbw sums each group of 8 bytes into a 64-bit lane.
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr0), _mm_setzero_si128()));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr1), _mm_setzero_si128()));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr2), _mm_setzero_si128()));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(_mm_setzero_si128(), acr3), _mm_setzero_si128()));
    }

    // may require SSE4, rewrite this part for actual SSE2.
    size_t count = _mm_extract_epi64(sum, 0) + _mm_extract_epi64(sum, 1);

    // scalar cleanup.  Could be optimized.  Compare as unsigned char so bytes >= 128 also match.
    while (p != (const char *)s + n) count += (unsigned char)*p++ == (unsigned char)c;
    return count;
}

See also https://gist.github.com/powturbo for an AVX2 implementation.

6 Comments

With some compilers, _mm_extract_epi64(sum, 1) will only compile with SSE4.1. You could use _mm_sub_epi8 inside the inner loop to avoid needing to negate the accumulators before psadbw. acr0 -= -1 is the same as acr0 += 1.
How much speedup does that prefetch give? On IvyBridge and later, with hardware next-page prefetch, it shouldn't make much difference.
Also, you could do much better for small uneven-size buffers with a cleanup loop that went 1 vector at a time, then maybe 1 movq, instead of up to 63 one-byte-at-a-time iterations. Or maybe use a load that goes right up to the end of the buffer and mask off the overlapping bytes that you'd double-count. (e.g. load a mask from a sliding window of ...,0,0,0,-1,-1,-1,-1,..., like this stackoverflow.com/questions/34306933/…)
@PeterCordes thanks for your suggestions. Prefetch speed up is ~10% on i2600k and large buffers.
Interesting. If I get around to it, I'll test on Skylake. (Probably much lower speedup because of next-page prefetching.) I don't have an IvB system, but IvB apparently has some kind of major throughput bottleneck for SW prefetch instructions.
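As a rough illustration of the cleanup suggested in the comments (one vector at a time, then scalar), here is a C sketch that assumes x86-64 and SSE2 only. The name memcount_tail_sse2 is hypothetical; it is meant for the bytes left over after the unrolled loop and replaces _mm_extract_epi64 with an SSE2-only reduction.

```c
#include <emmintrin.h>  /* SSE2 intrinsics */
#include <stddef.h>

/* Hypothetical cleanup routine: counts bytes equal to c in the n bytes at s. */
size_t memcount_tail_sse2(const void *s, int c, size_t n) {
    const unsigned char *p = s;
    const __m128i cv = _mm_set1_epi8((char)c);
    const __m128i zero = _mm_setzero_si128();
    __m128i sum = zero;

    /* One vector per iteration; psadbw runs every time, so nothing can overflow. */
    for (; n >= 16; p += 16, n -= 16) {
        __m128i eq = _mm_cmpeq_epi8(cv, _mm_loadu_si128((const __m128i *)p));
        sum = _mm_add_epi64(sum, _mm_sad_epu8(_mm_sub_epi8(zero, eq), zero));
    }

    /* SSE2-only horizontal add of the two 64-bit lanes. */
    sum = _mm_add_epi64(sum, _mm_unpackhi_epi64(sum, sum));
    size_t count = (size_t)_mm_cvtsi128_si64(sum);

    /* At most 15 scalar iterations remain. */
    for (; n != 0; ++p, --n)
        count += (*p == (unsigned char)c);
    return count;
}
```

The masked overlapping load mentioned in the comments would remove the scalar loop as well; it is left out here to keep the sketch short.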
1

Thanks to Marc Gravell for his tip, the following works:

var areEqual = Vector.Equals(vBytes, mask);
var negation = Vector.Negate(areEqual);
var count = Vector.Dot(negation, Vector<byte>.One);

Marc has a blog post with more info on the subject.
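To see why the negation step is needed, here is a scalar C model of the three calls, one byte lane at a time. It is only a sketch of the arithmetic, not how the runtime implements Vector.Dot, and count_via_negate_dot is a hypothetical name. Everything wraps modulo 256, as byte lanes do.

```c
#include <stddef.h>
#include <stdint.h>

/* Scalar model (hypothetical helper) of Vector.Equals -> Vector.Negate -> Vector.Dot
   on byte lanes. */
uint8_t count_via_negate_dot(const uint8_t *lanes, size_t n, uint8_t value) {
    uint8_t dot = 0;
    for (size_t i = 0; i < n; ++i) {
        /* Vector.Equals: all bits set (0xFF) on a match, 0 otherwise */
        uint8_t equal = (lanes[i] == value) ? 0xFF : 0x00;
        /* Vector.Negate: -0xFF wraps to 1, -0 stays 0 */
        uint8_t negated = (uint8_t)(0u - equal);
        /* Vector.Dot with Vector<byte>.One: multiply by 1 and sum the lanes */
        dot = (uint8_t)(dot + negated * 1u);
    }
    return dot;
}
```

The result is itself a byte, which is enough for a single vector because a vector has far fewer than 256 byte lanes.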

2 Comments

nice; easier to read as var count = Vector.Dot(-Vector.Equals(vBytes, mask), Vector<byte>.One);, but: like it; note: you need to be really careful how you "load" the vectors for SIMD; if you aren't careful, you can lose all the benefit due to load overhead. Span<T> is a great way to load them - raw arrays: not usually so much
Agreed, this was a contrived example to get the core of it working, in production it will be further optimized. Thanks for the light-bulb moment though!
