Commit 76ebc5ad authored by knopp's avatar knopp

addition of ARM NEON intrinsics

parent 88d5bf42
Pipeline #10776 failed with stage
in 0 seconds
......@@ -254,6 +254,18 @@ void build_decoder_tree(t_nrPolar_params *pp) {
}
#if defined(__arm__) || defined(__aarch64__)
// translate 1-1 SIMD functions from SSE to NEON
#define __m128i int16x8_t
#define __m64 int8x8_t
#define _mm_abs_epi16(a) vabsq_s16(a)
#define _mm_min_epi16(a,b) vminq_s16(a,b)
#define _mm_subs_epi16(a,b) vsubq_s16(a,b)
#define _mm_abs_pi16(a) vabs_s16(a)
#define _mm_min_pi16(a,b) vmin_s16(a,b)
#define _mm_subs_pi16(a,b) vsub_s16(a,b)
#endif
void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
int16_t *alpha_v=node->alpha;
int16_t *alpha_l=node->left->alpha;
......@@ -270,7 +282,6 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
if (node->left->all_frozen == 0) {
#if defined(__AVX2__)
int avx2mod = (node->Nv/2)&15;
if (avx2mod == 0) {
......@@ -284,14 +295,7 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
absa256 =_mm256_abs_epi16(a256);
absb256 =_mm256_abs_epi16(b256);
minabs256 =_mm256_min_epi16(absa256,absb256);
((__m256i*)alpha_l)[i] =_mm256_sign_epi16(minabs256,_mm256_xor_si256(a256,b256));
/* for (int j=0;j<16;j++) printf("alphal[%d] %d (%d,%d,%d)\n",
(16*i) + j,
alpha_l[(16*i)+j],
((int16_t*)&minabs256)[j],
alpha_v[(16*i)+j],
alpha_v[(16*i)+j+(node->Nv/2)]);
*/
((__m256i*)alpha_l)[i] =_mm256_sign_epi16(minabs256,_mm256_sign_epi16(a256,b256));
}
}
else if (avx2mod == 8) {
......@@ -301,7 +305,7 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
absa128 =_mm_abs_epi16(a128);
absb128 =_mm_abs_epi16(b128);
minabs128 =_mm_min_epi16(absa128,absb128);
*((__m128i*)alpha_l) =_mm_sign_epi16(minabs128,_mm_xor_si128(a128,b128));
*((__m128i*)alpha_l) =_mm_sign_epi16(minabs128,_mm_sign_epi16(a128,b128));
}
else if (avx2mod == 4) {
__m64 a64,b64,absa64,absb64,minabs64;
......@@ -310,11 +314,56 @@ void applyFtoleft(t_nrPolar_params *pp,decoder_node_t *node) {
absa64 =_mm_abs_pi16(a64);
absb64 =_mm_abs_pi16(b64);
minabs64 =_mm_min_pi16(absa64,absb64);
*((__m64*)alpha_l) =_mm_sign_pi16(minabs64,_mm_xor_si64(a64,b64));
*((__m64*)alpha_l) =_mm_sign_pi16(minabs64,_mm_sign_pi16(a64,b64));
}
else
#else
int sse4mod = (node->Nv/2)&7;
int sse4len = node->Nv/2/8;
#if defined(__arm__) || defined(__aarch64__)
int16x8_t signatimesb,comp1,comp2,negminabs128;
int16x8_t zero=vdupq_n_s16(0);
#endif
if (sse4mod == 0) {
for (int i=0;i<sse4len;i++) {
__m128i a128,b128,absa128,absb128,minabs128;
int sse4len = node->Nv/2/8;
a128 =*((__m128i*)alpha_v);
b128 =((__m128i*)alpha_v)[1];
absa128 =_mm_abs_epi16(a128);
absb128 =_mm_abs_epi16(b128);
minabs128 =_mm_min_epi16(absa128,absb128);
#if defined(__arm__) || defined(__aarch64__)
// unfortunately no direct equivalent to _mm_sign_epi16
signatimesb=vxorrq_s16(a128,b128);
comp1=vcltq_s16(signatimesb,zero);
comp2=vcgeq_s16(signatimesb,zero);
negminabs128=vnegq_s16(minabs128);
*((__m128i*)alpha_l) =vorrq_s16(vandq_s16(minabs128,comp0),vandq_s16(negminabs128,comp1));
#else
*((__m128i*)alpha_l) =_mm_sign_epi16(minabs128,_mm_sign_epi16(a128,b128));
#endif
}
}
else if (sse4mod == 4) {
__m64 a64,b64,absa64,absb64,minabs64;
a64 =*((__m64*)alpha_v);
b64 =((__m64*)alpha_v)[1];
absa64 =_mm_abs_pi16(a64);
absb64 =_mm_abs_pi16(b64);
minabs64 =_mm_min_pi16(absa64,absb64);
#if defined(__arm__) || defined(__aarch64__)
AssertFatal(1==0,"Need to do this still for ARM\n");
#else
*((__m64*)alpha_l) =_mm_sign_pi16(minabs64,_mm_sign_epi16(a64,b64));
#endif
}
else
#endif
{
{ // equvalent scalar code to above, activated only on non x86/ARM architectures
for (int i=0;i<node->Nv/2;i++) {
a=alpha_v[i];
b=alpha_v[i+(node->Nv/2)];
......@@ -367,9 +416,34 @@ void applyGtoright(t_nrPolar_params *pp,decoder_node_t *node) {
else if (avx2mod == 8) {
((__m128i *)alpha_r)[0] = _mm_subs_epi16(((__m128i *)alpha_v)[1],_mm_sign_epi16(((__m128i *)alpha_v)[0],((__m128i *)betal)[0]));
}
else if (avx2mod == 4) {
((__m64 *)alpha_r)[0] = _mm_subs_pi16(((__m64 *)alpha_v)[1],_mm_sign_pi16(((__m64 *)alpha_v)[0],((__m64 *)betal)[0]));
}
else
#else
int sse4mod = (node->Nv/2)&7;
if (sse4mod == 0) {
int sse4len = node->Nv/2/8;
for (int i=0;i<sse4len;i++) {
#if defined(__arm__) || defined(__aarch64__)
((int16x8_t *)alpha_r)[0] = vsubq_s16(((int16x8_t *)alpha_v)[1],vmulq_epi16(((int16x8_t *)alpha_v)[0],((int16x8_t *)betal)[0]));
#else
((__m128i *)alpha_r)[0] = _mm_subs_epi16(((__m128i *)alpha_v)[1],_mm_sign_epi16(((__m128i *)alpha_v)[0],((__m128i *)betal)[0]));
#endif
}
}
else if (sse4mod == 4) {
#if defined(__arm__) || defined(__aarch64__)
((int16x4_t *)alpha_r)[0] = vsub_s16(((int16x4_t *)alpha_v)[1],vmul_epi16(((int16x4_t *)alpha_v)[0],((int16x4_t *)betal)[0]));
#else
((__m64 *)alpha_r)[0] = _mm_subs_pi16(((__m64 *)alpha_v)[1],_mm_sign_pi16(((__64 *)alpha_v)[0],((__m64 *)betal)[0]));
#endif
}
else
#endif
{
{// equvalent scalar code to above, activated only on non x86/ARM architectures
for (int i=0;i<node->Nv/2;i++) {
alpha_r[i] = alpha_v[i+(node->Nv/2)] - (betal[i]*alpha_v[i]);
}
......@@ -385,10 +459,10 @@ void applyGtoright(t_nrPolar_params *pp,decoder_node_t *node) {
}
int16_t minus1[16] = {-1,-1,-1,-1,
-1,-1,-1,-1,
-1,-1,-1,-1,
-1,-1,-1,-1};
int16_t all1[16] = {1,1,1,1,
1,1,1,1,
1,1,1,1,
1,1,1,1};
void computeBeta(t_nrPolar_params *pp,decoder_node_t *node) {
......@@ -401,27 +475,37 @@ void computeBeta(t_nrPolar_params *pp,decoder_node_t *node) {
if (node->left->all_frozen==0) { // if left node is not aggregation of frozen bits
#if defined(__AVX2__)
int avx2mod = (node->Nv/2)&15;
register __m256i allones=*((__m256i*)all1);
if (avx2mod == 0) {
int avx2len = node->Nv/2/16;
for (int i=0;i<avx2len;i++) {
((__m256i*)betav)[i] = _mm256_sign_epi16(((__m256i*)betar)[i],
((__m256i*)betal)[i]);
((__m256i*)betav)[i] = _mm256_sign_epi16(((__m256i*)betav)[i],
((__m256i*)minus1)[0]);
((__m256i*)betav)[i] = _mm256_or_si256(_mm256_cmpeq_epi16(((__m256i*)betar)[i],
((__m256i*)betal)[i]),allones);
}
}
else if (avx2mod == 8) {
((__m128i*)betav)[0] = _mm_sign_epi16(((__m128i*)betar)[0],
((__m128i*)betal)[0]);
((__m128i*)betav)[0] = _mm_sign_epi16(((__m128i*)betav)[0],
((__m128i*)minus1)[0]);
((__m128i*)betav)[0] = _mm_or_si128(_mm_cmpeq_epi16(((__m128i*)betar)[0],
((__m128i*)betal)[0]),*((__m128i*)all1));
}
else if (avx2mod == 4) {
((__m64*)betav)[0] = _mm_sign_pi16(((__m64*)betar)[0],
((__m64*)betal)[0]);
((__m64*)betav)[0] = _mm_sign_pi16(((__m64*)betav)[0],
((__m64*)minus1)[0]);
((__m64*)betav)[0] = _mm_or_si64(_mm_cmpeq_pi16(((__m64*)betar)[0],
((__m64*)betal)[0]),*((__m64*)all1));
}
else
#else
int avx2mod = (node->Nv/2)&15;
if (ssr4mod == 0) {
int ssr4len = node->Nv/2/8;
register __m128i allones=*((__m128i*)all1);
for (int i=0;i<sse4len;i++) {
((__m256i*)betav)[i] = _mm_or_si128(_mm_cmpeq_epi16(((__m128i*)betar)[i],
((__m128i*)betal)[i]),allones));
}
}
else if (sse4mod == 4) {
((__m64*)betav)[0] = _mm_or_si64(_mm_cmpeq_pi16(((__m64*)betar)[0],
((__m64*)betal)[0]),*((__m64*)all1));
}
else
#endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment