diff --git a/src/vector.h b/src/vector.h index 937df5b53455b01c73c97b1df42dc81933b802fb..0e92734a6d936c4b7dec51ff679b1a1a184c0820 100644 --- a/src/vector.h +++ b/src/vector.h @@ -82,12 +82,13 @@ #define vec_cmp_result(a) a #define vec_form_int_mask(a) a #define vec_and(a, b) _mm512_and_ps(a, b) -#define vec_mask_and(a, b) a & b -#define vec_and_mask(a, mask) _mm512_maskz_expand_ps(mask, a) +#define vec_mask_and(a, b) _mm512_kand(a, b) +#define vec_and_mask(a, mask) _mm512_maskz_expand_ps(mask, a) /* TODO: Alternative needs to be found. */ #define vec_init_mask(mask) mask = 0xFFFF #define vec_zero_mask(mask) mask = 0 #define vec_create_mask(mask, cond) mask = cond #define vec_pad_mask(mask, pad) mask = mask >> (pad) +#define vec_blend(mask, a, b) _mm512_mask_blend_ps(mask, a, b) #define vec_todbl_lo(a) _mm512_cvtps_pd(_mm512_extract128_ps(a, 0)) #define vec_todbl_hi(a) _mm512_cvtps_pd(_mm512_extract128_ps(a, 1)) #define vec_dbl_tofloat(a, b) _mm512_insertf128(_mm512_castps128_ps512(a), b, 1)