#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

/*
 * Reference table: lut[i] is 256 / (i + 1) rounded to nearest, saturated to
 * 255 at i == 0, and lut[255] is 0 (matching the SIMD clip below, which
 * zeroes every result once z + 1 >= 256).
 */
static const uint8_t lut[256] = {
    255, 128,  85,  64,  51,  43,  37,  32,  28,  26,  23,  21,  20,  18,  17,  16,
     15,  14,  13,  13,  12,  12,  11,  11,  10,  10,   9,   9,   9,   9,   8,   8,
      8,   8,   7,   7,   7,   7,   7,   6,   6,   6,   6,   6,   6,   6,   5,   5,
      5,   5,   5,   5,   5,   5,   5,   5,   4,   4,   4,   4,   4,   4,   4,   4,
      4,   4,   4,   4,   4,   4,   4,   4,   4,   3,   3,   3,   3,   3,   3,   3,
      3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,
      3,   3,   3,   3,   3,   3,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
      2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
      1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   0,
};

/* z is tested over [0, NUM_Z). */
enum { NUM_Z = 4095 };

/* Expected SIMD result for z; every z >= 255 maps to lut[255] (== 0). */
static int expected(int z)
{
    return lut[z > 255 ? 255 : z];
}

/*
 * Compare the first n lanes against expected(z). On the first mismatching
 * lane, print one diagnostic tagged with `simd` and return 1; otherwise 0.
 */
static int check_lanes(const char *simd, int z, const int32_t *lanes, int n)
{
    const int a = expected(z);
    for (int i = 0; i < n; i++) {
        const int b = lanes[i];
        if (a != b) {
            fprintf(stderr, "%s mismatch z=%d a=%d b=%d\n", simd, z, a, b);
            return 1;
        }
    }
    return 0;
}

/*
 * SSE2 check of res = min(cvtps(rcpps(z + 1) * 256), clip), where clip is
 * 255 while z + 1 < 256 and 0 otherwise. Returns 1 if any lane of any z
 * disagrees with lut, else 0.
 */
static int test_xmm(void)
{
    const __m128 f256 = _mm_set1_ps(256.0f);
    int ret = 0;

    for (int z = 0; z < NUM_Z; z++) {
        union {
            int32_t i[4];
            __m128i m;
        } buf;
        __m128 zp1 = _mm_set1_ps((float)(z + 1));
        __m128 recip = _mm_mul_ps(_mm_rcp_ps(zp1), f256);
        /* Both operands are positive floats, so comparing their bit patterns
         * as signed 32-bit integers orders them like the floats themselves.
         * The all-ones "true" mask shifted right by 24 is 255; "false" is 0. */
        __m128i clip = _mm_srli_epi32(
            _mm_cmpgt_epi32(_mm_castps_si128(f256), _mm_castps_si128(zp1)),
            24);
        /* 16-bit min (SSE2) stands in for the SSE4.1 32-bit min: every lane
         * holds a small non-negative value, so the high 16-bit halves are
         * zero and the low halves decide the result. */
        __m128i res = _mm_min_epi16(_mm_cvtps_epi32(recip), clip);
        _mm_store_si128(&buf.m, res);

        ret |= check_lanes("xmm", z, buf.i, 4);
    }
    return ret;
}

#ifdef __AVX2__
/* Same check as test_xmm using the 256-bit AVX/AVX2 forms (8 lanes). */
static int test_ymm(void)
{
    const __m256 f256 = _mm256_set1_ps(256.0f);
    int ret = 0;

    for (int z = 0; z < NUM_Z; z++) {
        union {
            int32_t i[8];
            __m256i m;
        } buf;
        __m256 zp1 = _mm256_set1_ps((float)(z + 1));
        __m256 recip = _mm256_mul_ps(_mm256_rcp_ps(zp1), f256);
        __m256i clip = _mm256_srli_epi32(
            _mm256_cmpgt_epi32(_mm256_castps_si256(f256),
                               _mm256_castps_si256(zp1)),
            24);
        /* AVX2 has _mm256_min_epi32, but the 16-bit min is kept to mirror
         * the SSE2 path; for these small non-negative lanes it is equivalent. */
        __m256i res = _mm256_min_epi16(_mm256_cvtps_epi32(recip), clip);
        _mm256_store_si256(&buf.m, res);

        ret |= check_lanes("ymm", z, buf.i, 8);
    }
    return ret;
}
#endif

/*
 * Run the rcpps-vs-lut checks. Exit status is 0 when every checked lane
 * matches lut and 1 otherwise. The ymm check only exists when the build
 * enables AVX2 (e.g. -mavx2); otherwise it is reported as skipped rather
 * than failing to compile.
 */
int main(void)
{
    int ret = test_xmm();

#ifdef __AVX2__
    ret |= test_ymm();
#else
    fprintf(stderr, "ymm test skipped (build with AVX2 enabled, e.g. -mavx2)\n");
#endif

    if (!ret) {
        fprintf(stderr, "all ok!\n");
    }
    return ret;
}