diff --git a/src/vector.h b/src/vector.h
index 19e3121dfb1a5c0af4370b8820b572f343bfe19f..627d82c9f722c2a13decdae57bf91fc3b1622ae7 100644
--- a/src/vector.h
+++ b/src/vector.h
@@ -60,9 +60,13 @@
 #define vec_dbl_set(a, b, c, d, e, f, g, h) \
   _mm512_set_pd(h, g, f, e, d, c, b, a)
 #define vec_add(a, b) _mm512_add_ps(a, b)
+#define vec_mask_add(a, b, mask) _mm512_mask_add_ps(a, mask, b, a)
 #define vec_sub(a, b) _mm512_sub_ps(a, b)
+#define vec_mask_sub(a, b, mask) _mm512_mask_sub_ps(a, mask, a, b)
 #define vec_mul(a, b) _mm512_mul_ps(a, b)
+#define vec_div(a, b) _mm512_div_ps(a, b)
 #define vec_fma(a, b, c) _mm512_fmadd_ps(a, b, c)
+#define vec_fnma(a, b, c) _mm512_fnmadd_ps(a, b, c)
 #define vec_sqrt(a) _mm512_sqrt_ps(a)
 #define vec_rcp(a) _mm512_rcp14_ps(a)
 #define vec_rsqrt(a) _mm512_rsqrt14_ps(a)
@@ -75,7 +79,15 @@
 #define vec_cmp_lt(a, b) _mm512_cmp_ps_mask(a, b, _CMP_LT_OQ)
 #define vec_cmp_lte(a, b) _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ)
 #define vec_cmp_gte(a, b) _mm512_cmp_ps_mask(a, b, _CMP_GE_OQ)
+#define vec_cmp_result(a) a
+#define vec_form_int_mask(a) a
 #define vec_and(a, b) _mm512_and_ps(a, b)
+#define vec_mask_and(a, b) a & b
+#define vec_and_mask(a, mask) _mm512_maskz_expand_ps(mask, a)
+#define vec_init_mask(mask) mask = 0xFFFF
+#define vec_zero_mask(mask) mask = 0
+#define vec_create_mask(mask, cond) mask = cond
+#define vec_pad_mask(mask, pad) mask = mask >> (pad)
 #define vec_todbl_lo(a) _mm512_cvtps_pd(_mm512_extract128_ps(a, 0))
 #define vec_todbl_hi(a) _mm512_cvtps_pd(_mm512_extract128_ps(a, 1))
 #define vec_dbl_tofloat(a, b) _mm512_insertf128(_mm512_castps128_ps512(a), b, 1)
@@ -145,8 +157,11 @@
 #define vec_set(a, b, c, d, e, f, g, h) _mm256_set_ps(h, g, f, e, d, c, b, a)
 #define vec_dbl_set(a, b, c, d) _mm256_set_pd(d, c, b, a)
 #define vec_add(a, b) _mm256_add_ps(a, b)
+#define vec_mask_add(a, b, mask) vec_add(a, vec_and(b, mask.v))
 #define vec_sub(a, b) _mm256_sub_ps(a, b)
+#define vec_mask_sub(a, b, mask) vec_sub(a, vec_and(b, mask.v))
 #define vec_mul(a, b) _mm256_mul_ps(a, b)
+#define vec_div(a, b) _mm256_div_ps(a, b)
 #define vec_sqrt(a) _mm256_sqrt_ps(a)
 #define vec_rcp(a) _mm256_rcp_ps(a)
 #define vec_rsqrt(a) _mm256_rsqrt_ps(a)
@@ -160,6 +175,7 @@
 #define vec_cmp_lte(a, b) _mm256_cmp_ps(a, b, _CMP_LE_OQ)
 #define vec_cmp_gte(a, b) _mm256_cmp_ps(a, b, _CMP_GE_OQ)
 #define vec_cmp_result(a) _mm256_movemask_ps(a)
+#define vec_form_int_mask(a) _mm256_movemask_ps(a.v)
 #define vec_and(a, b) _mm256_and_ps(a, b)
 #define vec_mask_and(a, b) _mm256_and_ps(a.v, b.v)
 #define vec_and_mask(a, mask) _mm256_and_ps(a, mask.v)
@@ -193,6 +209,13 @@
   a.v = _mm256_hadd_ps(a.v, a.v); \
   b += a.f[0] + a.f[4];
 
+/* Performs a horizontal maximum on the vector and takes the maximum of the
+ * result with a float, b. */
+#define VEC_HMAX(a, b)                                     \
+  {                                                        \
+    for (int k = 0; k < VEC_SIZE; k++) b = max(b, a.f[k]); \
+  }
+
 /* Returns the lower 128-bits of the 256-bit vector. */
 #define VEC_GET_LOW(a) _mm256_castps256_ps128(a)
 
@@ -202,6 +225,7 @@
 /* Check if we have AVX2 intrinsics alongside AVX */
 #ifdef HAVE_AVX2
 #define vec_fma(a, b, c) _mm256_fmadd_ps(a, b, c)
+#define vec_fnma(a, b, c) _mm256_fnmadd_ps(a, b, c)
 
 /* Used in VEC_FORM_PACKED_MASK */
 #define identity_indices 0x0706050403020100
@@ -211,19 +235,18 @@
 /* Takes an integer mask and forms a left-packed integer vector
  * containing indices of the set bits in the integer mask.
  * Also returns the total number of bits set in the mask. */
-#define VEC_FORM_PACKED_MASK(mask, v_mask, pack)                               \
+#define VEC_FORM_PACKED_MASK(mask, packed_mask)                                \
   {                                                                            \
     unsigned long expanded_mask = _pdep_u64(mask, 0x0101010101010101);         \
     expanded_mask *= 0xFF;                                                     \
     unsigned long wanted_indices = _pext_u64(identity_indices, expanded_mask); \
     __m128i bytevec = _mm_cvtsi64_si128(wanted_indices);                       \
-    v_mask = _mm256_cvtepu8_epi32(bytevec);                                    \
-    pack += __builtin_popcount(mask);                                          \
+    packed_mask.m = _mm256_cvtepu8_epi32(bytevec);                             \
   }
 
 /* Performs a left-pack on a vector based upon a mask and returns the result. */
 #define VEC_LEFT_PACK(a, mask, result) \
-  vec_unaligned_store(_mm256_permutevar8x32_ps(a, mask), result)
+  vec_unaligned_store(_mm256_permutevar8x32_ps(a, mask.m), result)
 #endif /* HAVE_AVX2 */
 
 /* Create an FMA using vec_add and vec_mul if AVX2 is not present. */
@@ -231,6 +254,11 @@
 #define vec_fma(a, b, c) vec_add(vec_mul(a, b), c)
 #endif
 
+/* Create a negated FMA using vec_sub and vec_mul if AVX2 is not present. */
+#ifndef vec_fnma
+#define vec_fnma(a, b, c) vec_sub(c, vec_mul(a, b))
+#endif
+
 /* Form a packed mask without intrinsics if AVX2 is not present. */
 #ifndef VEC_FORM_PACKED_MASK
@@ -294,6 +322,7 @@
 #define vec_add(a, b) _mm_add_ps(a, b)
 #define vec_sub(a, b) _mm_sub_ps(a, b)
 #define vec_mul(a, b) _mm_mul_ps(a, b)
+#define vec_div(a, b) _mm_div_ps(a, b)
 #define vec_sqrt(a) _mm_sqrt_ps(a)
 #define vec_rcp(a) _mm_rcp_ps(a)
 #define vec_rsqrt(a) _mm_rsqrt_ps(a)
@@ -346,6 +375,13 @@ typedef union {
   int i[VEC_SIZE];
 } vector;
 
+/* Define the mask type depending on the instruction set used. */
+#ifdef HAVE_AVX512_F
+typedef __mmask16 mask_t;
+#else
+typedef vector mask_t;
+#endif
+
 /**
  * @brief Calculates the inverse ($1/x$) of a vector using intrinsics and a
  * Newton iteration to obtain the correct level of accuracy.
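Note on the new mask abstraction (illustrative, not part of the patch): mask_t lets calling code treat the AVX-512 hardware mask and the AVX/SSE "vector of all-ones or all-zeros" mask the same way, and the vec_mask_* wrappers apply an operation only in the active lanes. A minimal sketch of the intended use, assuming an AVX-512 build (the excerpt only shows the AVX-512 forms of vec_init_mask and vec_pad_mask); the names sum, term and padding are hypothetical:

    vector sum, term;       /* assume both already hold data */
    mask_t mask;
    const int padding = 3;  /* e.g. 3 leftover lanes in a final, partial vector */

    /* Activate every lane, then switch off the last `padding` lanes. */
    vec_init_mask(mask);
    vec_pad_mask(mask, padding);

    /* sum += term, but only in the lanes that are still active. */
    sum.v = vec_mask_add(sum.v, term.v, mask);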
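The left-pack macros are meant to be chained after a comparison: the lane mask is collapsed into an integer bit mask, VEC_FORM_PACKED_MASK turns that into permutation indices, and VEC_LEFT_PACK stores the passing lanes contiguously. A sketch assuming an AVX2 build; r2, h2 and compacted are hypothetical names:

    vector r2, h2, mask;    /* assume r2 and h2 already hold data */
    mask_t packed;
    float compacted[VEC_SIZE];

    /* Per-lane test: which lanes satisfy r2 < h2? */
    mask.v = vec_cmp_lt(r2.v, h2.v);

    /* Collapse the lane mask into an integer bit mask and count the hits;
     * with the new macro the popcount is done by the caller. */
    int int_mask = vec_form_int_mask(mask);
    int count = __builtin_popcount(int_mask);

    /* Build the permutation indices and store the passing lanes left-packed:
     * the first `count` entries of `compacted` are the selected values. */
    VEC_FORM_PACKED_MASK(int_mask, packed);
    VEC_LEFT_PACK(r2.v, packed, compacted);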