#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

// Scalar reference table that the vector results in main() are compared against.
// Entry i equals 256 / (i + 1) rounded to nearest and capped at 255
// (so lut[0] is 255, not 256). The final entry, lut[255], is 0; main() clamps
// its index to 255, so that entry is the expected value for every z >= 255.
const uint8_t lut[256] = {
    255, 128,  85,  64,  51,  43,  37,  32,  28,  26,  23,  21,  20,  18,  17,  16,
     15,  14,  13,  13,  12,  12,  11,  11,  10,  10,   9,   9,   9,   9,   8,   8,
      8,   8,   7,   7,   7,   7,   7,   6,   6,   6,   6,   6,   6,   6,   5,   5,
      5,   5,   5,   5,   5,   5,   5,   5,   4,   4,   4,   4,   4,   4,   4,   4,
      4,   4,   4,   4,   4,   4,   4,   4,   4,   3,   3,   3,   3,   3,   3,   3,
      3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
      3,   3,   3,   3,   3,   3,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   0,
};

// Reference value for a given z: table lookup with the index saturated at 255.
static int rcp_expected(int z) {
    return lut[z < 256 ? z : 255];
}

// Runs the 128-bit (xmm) sequence for one z and compares all four lanes with
// the table. Reports the first differing lane on stderr.
// Returns 1 on mismatch, 0 when every lane matches.
static int rcp_check_xmm(int z) {
    union {
        int32_t i[4];
        __m128i m;
    } lanes;
    const __m128 scale = _mm_set1_ps(256.0f);
    const __m128 denom = _mm_set1_ps(z + 1);

    // 256 * approx(1 / (z + 1)), using the hardware reciprocal estimate
    const __m128 scaled = _mm_mul_ps(_mm_rcp_ps(denom), scale);
    // Signed integer compare on the raw float bit patterns: all-ones where
    // bits(256.0f) > bits(z + 1); the logical shift by 24 leaves 0xFF or 0.
    const __m128i limit = _mm_srli_epi32(
        _mm_cmpgt_epi32(_mm_castps_si128(scale), _mm_castps_si128(denom)), 24);
    // Note the 16-bit min applied to 32-bit lanes; kept exactly as-is since
    // this instruction sequence is what the test exercises.
    const __m128i clamped = _mm_min_epi16(_mm_cvtps_epi32(scaled), limit);
    _mm_store_si128(&lanes.m, clamped);

    const int a = rcp_expected(z);
    int lane = 0;
    while (lane < 4 && lanes.i[lane] == a)
        lane++;
    if (lane == 4)
        return 0;

    const int b = lanes.i[lane];
    fprintf(stderr, "xmm mismatch z=%d a=%d b=%d\n", z, a, b);
    return 1;
}

// Runs the 256-bit (ymm) sequence for one z and compares all eight lanes with
// the table. Reports the first differing lane on stderr.
// Returns 1 on mismatch, 0 when every lane matches.
static int rcp_check_ymm(int z) {
    union {
        int32_t i[8];
        __m256i m;
    } lanes;
    const __m256 scale = _mm256_set1_ps(256.0f);
    const __m256 denom = _mm256_set1_ps(z + 1);

    // Same steps as the xmm variant, on 256-bit vectors.
    const __m256 scaled = _mm256_mul_ps(_mm256_rcp_ps(denom), scale);
    const __m256i limit = _mm256_srli_epi32(
        _mm256_cmpgt_epi32(_mm256_castps_si256(scale), _mm256_castps_si256(denom)), 24);
    const __m256i clamped = _mm256_min_epi16(_mm256_cvtps_epi32(scaled), limit);
    _mm256_store_si256(&lanes.m, clamped);

    const int a = rcp_expected(z);
    int lane = 0;
    while (lane < 8 && lanes.i[lane] == a)
        lane++;
    if (lane == 8)
        return 0;

    const int b = lanes.i[lane];
    fprintf(stderr, "ymm mismatch z=%d a=%d b=%d\n", z, a, b);
    return 1;
}

// Checks the xmm and then the ymm reciprocal sequence for z = 0 .. 4094
// against lut. Every mismatching z is reported on stderr; "all ok!" is
// printed when none occurred. Returns 0 on success, 1 if anything mismatched.
int main(int argc, char *argv[]) {
    int failed = 0;

    // test xmm rcpps
    for (int z = 0; z < 4095; z++)
        failed |= rcp_check_xmm(z);

    // test ymm rcpps
    for (int z = 0; z < 4095; z++)
        failed |= rcp_check_ymm(z);

    if (failed)
        return failed;

    fprintf(stderr, "all ok!\n");
    return 0;
}