DEV Community

Antidisestablishmentarianism
Antidisestablishmentarianism

Posted on

C# SIMD byte array compare

My byte-array comparison, which I recently posted on Stack Overflow:
https://stackoverflow.com/a/69280107/13843929

`a1` and `a2` are global byte arrays (fields of the enclosing class). The code is written this way to fit the test program that the maintainers of the Stack Overflow post have been using, and I didn't want to maintain two copies of the code.

Quoting my own post on Stack Overflow:
"This is similar to others, but the difference here is that there is no falling through to the next highest number of bytes I can check at once, e.g. if I have 63 bytes (in my SIMD example) I can check the equality of the first 32 bytes, and then the last 32 bytes, which is faster than checking 32 bytes, 16 bytes, 8 bytes, and so on. The first check you enter is the only check you will need to compare all of the bytes."

It is the fastest performer in my tests.

// Requires: using System.Runtime.Intrinsics.X86;
/// <summary>
/// Compares the byte arrays held in the fields <c>a1</c> and <c>a2</c> for equality.
/// </summary>
/// <remarks>
/// Each size class is handled by exactly one branch, with no fall-through to smaller chunk
/// sizes. The bulk of the data is compared in full-width chunks, and the remainder is covered
/// by one final chunk that ends exactly at the last byte. That final chunk may overlap bytes
/// that were already compared; for example, 63 bytes are checked as [0, 32) and [31, 63).
/// AVX2 is used when available, then SSE2, and otherwise 64-bit scalar comparisons, so the
/// method does not throw <see cref="System.PlatformNotSupportedException"/> on CPUs lacking AVX2.
/// </remarks>
/// <returns>
/// <c>true</c> if both arrays are non-null, have the same length and identical contents
/// (two empty arrays compare equal); <c>false</c> if either array is <c>null</c> or they differ.
/// </returns>
public unsafe bool SIMDNoFallThrough()
{
    if (a1 == null || a2 == null)
        return false;

    // Same array instance: trivially equal, no memory needs to be read.
    if (object.ReferenceEquals(a1, a2))
        return true;

    int length = a1.Length;

    if (length != a2.Length)
        return false;

    // `fixed` on an empty array yields a null pointer, so the zero-length case must be
    // answered before anything is dereferenced.
    if (length == 0)
        return true;

    fixed (byte* p0 = a1, p1 = a2)
    {
        if (length >= 32 && Avx2.IsSupported)
        {
            // Full 32-byte chunks; the final chunk is left to the tail check below.
            for (int i = 0; i < length - 32; i += 32)
            {
                if (Avx2.MoveMask(Avx2.CompareEqual(Avx.LoadVector256(p0 + i), Avx.LoadVector256(p1 + i))) != -1)
                    return false;
            }

            // Last 32 bytes, possibly overlapping bytes already compared above.
            int tail = length - 32;
            return Avx2.MoveMask(Avx2.CompareEqual(Avx.LoadVector256(p0 + tail), Avx.LoadVector256(p1 + tail))) == -1;
        }

        if (length >= 16 && Sse2.IsSupported)
        {
            // With AVX2 present this branch only sees 16..31 bytes (one loop pass at most);
            // without AVX2 it also handles longer arrays.
            for (int i = 0; i < length - 16; i += 16)
            {
                if (Sse2.MoveMask(Sse2.CompareEqual(Sse2.LoadVector128(p0 + i), Sse2.LoadVector128(p1 + i))) != 0xFFFF)
                    return false;
            }

            int tail = length - 16;
            return Sse2.MoveMask(Sse2.CompareEqual(Sse2.LoadVector128(p0 + tail), Sse2.LoadVector128(p1 + tail))) == 0xFFFF;
        }

        // Scalar paths below use possibly unaligned loads, as the original did; this assumes
        // the target platform tolerates unaligned reads (x86/x64 do).
        if (length >= 8)
        {
            for (int i = 0; i < length - 8; i += 8)
            {
                if (*(ulong*)(p0 + i) != *(ulong*)(p1 + i))
                    return false;
            }

            int tail = length - 8;
            return *(ulong*)(p0 + tail) == *(ulong*)(p1 + tail);
        }

        // 4..7 bytes: first 4 and last 4 (overlapping) cover everything.
        if (length >= 4)
            return *(uint*)p0 == *(uint*)p1 && *(uint*)(p0 + (length - 4)) == *(uint*)(p1 + (length - 4));

        // 2..3 bytes: first 2 and last 2 (overlapping) cover everything.
        if (length >= 2)
            return *(ushort*)p0 == *(ushort*)p1 && *(ushort*)(p0 + (length - 2)) == *(ushort*)(p1 + (length - 2));

        // Exactly one byte.
        return *p0 == *p1;
    }
}
Enter fullscreen mode Exit fullscreen mode

Top comments (0)