diff --git a/openair1/PHY/INIT/nr_init_ue.c b/openair1/PHY/INIT/nr_init_ue.c
index e5402c1c044315024638fc08e3b9bf0c0bc3e6e8..2e32b7e3e69090c4c8ba0fa28c5e87acca8b952e 100644
--- a/openair1/PHY/INIT/nr_init_ue.c
+++ b/openair1/PHY/INIT/nr_init_ue.c
@@ -349,6 +349,7 @@ int init_nr_ue_signal(PHY_VARS_NR_UE *ue, int nb_connected_gNB)
     ue->nr_csi_rs_info->rank_indicator = (uint8_t*)malloc16_clear(sizeof(uint8_t));
     ue->nr_csi_rs_info->i1 = (uint8_t*)malloc16_clear(3*sizeof(uint8_t));
     ue->nr_csi_rs_info->i2 = (uint8_t*)malloc16_clear(sizeof(uint8_t));
+    ue->nr_csi_rs_info->precoded_sinr = (uint32_t*)malloc16_clear(sizeof(uint32_t));
     ue->nr_csi_rs_info->csi_rs_generated_signal = (int32_t **)malloc16(NR_MAX_NB_PORTS * sizeof(int32_t *) );
     for (i=0; i<NR_MAX_NB_PORTS; i++) {
       ue->nr_csi_rs_info->csi_rs_generated_signal[i] = (int32_t *) malloc16_clear(fp->samples_per_frame_wCP * sizeof(int32_t));
diff --git a/openair1/PHY/NR_UE_TRANSPORT/csi_rx.c b/openair1/PHY/NR_UE_TRANSPORT/csi_rx.c
index 4e3ac6dd69d889f703055bf69ad7e2be7830d0e3..7eb0b68307fe270e87f258820a2a28dc87e9f11a 100644
--- a/openair1/PHY/NR_UE_TRANSPORT/csi_rx.c
+++ b/openair1/PHY/NR_UE_TRANSPORT/csi_rx.c
@@ -206,6 +206,7 @@ int nr_csi_rs_channel_estimation(PHY_VARS_NR_UE *ue,
   int dataF_offset = proc->nr_slot_rx*ue->frame_parms.samples_per_slot_wCP;
   *noise_power = 0;
   int maxh = 0;
+  int count = 0;
 
   for (int ant_rx = 0; ant_rx < frame_parms->nb_antennas_rx; ant_rx++) {
 
@@ -291,6 +292,8 @@ int nr_csi_rs_channel_estimation(PHY_VARS_NR_UE *ue,
         continue;
       }
 
+      count++;
+
       uint16_t k = (frame_parms->first_carrier_offset + rb*NR_NB_SC_PER_RB) % frame_parms->ofdm_symbol_size;
       for(uint16_t port_tx = 0; port_tx<nr_csi_rs_info->N_ports; port_tx++) {
         int16_t *csi_rs_ls_estimated_channel = (int16_t*)&nr_csi_rs_info->csi_rs_ls_estimated_channel[ant_rx][port_tx][k];
@@ -350,6 +353,7 @@ int nr_csi_rs_channel_estimation(PHY_VARS_NR_UE *ue,
 
   *noise_power /= (frame_parms->nb_antennas_rx*nr_csi_rs_info->N_ports);
   nr_csi_rs_info->log2_maxh = log2_approx(maxh-1);
+  nr_csi_rs_info->log2_re = log2_approx(count-1);
 
 #ifdef NR_CSIRS_DEBUG
   LOG_I(NR_PHY, "Noise power estimation based on CSI-RS: %i\n", *noise_power);
@@ -463,9 +467,11 @@ int nr_csi_rs_pmi_estimation(PHY_VARS_NR_UE *ue,
                              fapi_nr_dl_config_csirs_pdu_rel15_t *csirs_config_pdu,
                              nr_csi_rs_info_t *nr_csi_rs_info,
                              int32_t ***csi_rs_estimated_channel_freq,
+                             uint32_t noise_power,
                              uint8_t rank_indicator,
                              uint8_t *i1,
-                             uint8_t *i2) {
+                             uint8_t *i2,
+                             uint32_t *precoded_sinr) {
 
   NR_DL_FRAME_PARMS *frame_parms = &ue->frame_parms;
   memset(i1,0,3*sizeof(uint8_t));
@@ -480,13 +486,18 @@ int nr_csi_rs_pmi_estimation(PHY_VARS_NR_UE *ue,
   // The first column is applicable if the UE is reporting a Rank = 1, whereas the second column is applicable if the
   // UE is reporting a Rank = 2.
 
-  if(ue->nr_csi_rs_info->N_ports == 1) {
+  if(nr_csi_rs_info->N_ports == 1 || noise_power == 0) {
     return 0;
   }
 
   if(rank_indicator == 0 || rank_indicator == 1) {
 
-    uint32_t metric[4] = {0};
+    int32_t sum_re[4] = {0};
+    int32_t sum_im[4] = {0};
+    int32_t sum2_re[4] = {0};
+    int32_t sum2_im[4] = {0};
+    int32_t tested_precoded_sinr[4] = {0};
+
     uint16_t shift = nr_csi_rs_info->log2_maxh + log2_approx(frame_parms->nb_antennas_rx) + log2_approx(csirs_config_pdu->nr_of_rbs) - 1;
 
     for (int rb = csirs_config_pdu->start_rb; rb < (csirs_config_pdu->start_rb+csirs_config_pdu->nr_of_rbs); rb++) {
@@ -497,35 +508,54 @@ int nr_csi_rs_pmi_estimation(PHY_VARS_NR_UE *ue,
       uint16_t k = (frame_parms->first_carrier_offset + rb * NR_NB_SC_PER_RB) % frame_parms->ofdm_symbol_size;
 
       for (int ant_rx = 0; ant_rx < frame_parms->nb_antennas_rx; ant_rx++) {
+
         int16_t *csi_rs_estimated_channel_p0 = (int16_t *) &csi_rs_estimated_channel_freq[ant_rx][0][k];
         int16_t *csi_rs_estimated_channel_p1 = (int16_t *) &csi_rs_estimated_channel_freq[ant_rx][1][k];
 
         // H_p0 + 1*H_p1 = (H_p0_re + H_p1_re) + 1j*(H_p0_im + H_p1_im)
-        metric[0] += ((csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[0])
-            + (csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[1]))>>shift;
+        sum_re[0] += (csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[0]);
+        sum_im[0] += (csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[1]);
+        sum2_re[0] += ((csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[0]))>>nr_csi_rs_info->log2_re;
+        sum2_im[0] += ((csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[1]))>>nr_csi_rs_info->log2_re;
 
         // H_p0 + 1j*H_p1 = (H_p0_re - H_p1_im) + 1j*(H_p0_im + H_p1_re)
-        metric[1] += ((csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[1])
-                      + (csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[0]))>>shift;
+        sum_re[1] += (csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[1]);
+        sum_im[1] += (csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[0]);
+        sum2_re[1] += ((csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[1]))>>nr_csi_rs_info->log2_re;
+        sum2_im[1] += ((csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[1]+csi_rs_estimated_channel_p1[0]))>>nr_csi_rs_info->log2_re;
 
         // H_p0 - 1*H_p1 = (H_p0_re - H_p1_re) + 1j*(H_p0_im - H_p1_im)
-        metric[2] += ((csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[0])
-                      + (csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[1]))>>shift;
+        sum_re[2] += (csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[0]);
+        sum_im[2] += (csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[1]);
+        sum2_re[2] += ((csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[0]-csi_rs_estimated_channel_p1[0]))>>nr_csi_rs_info->log2_re;
+        sum2_im[2] += ((csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[1]))>>nr_csi_rs_info->log2_re;
 
         // H_p0 - 1j*H_p1 = (H_p0_re + H_p1_im) + 1j*(H_p0_im - H_p1_re)
-        metric[3] += ((csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[1])
-                      + (csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[0]))>>shift;
+        sum_re[3] += (csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[1]);
+        sum_im[3] += (csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[0]);
+        sum2_re[3] += ((csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[1])*(csi_rs_estimated_channel_p0[0]+csi_rs_estimated_channel_p1[1]))>>nr_csi_rs_info->log2_re;
+        sum2_im[3] += ((csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[0])*(csi_rs_estimated_channel_p0[1]-csi_rs_estimated_channel_p1[0]))>>nr_csi_rs_info->log2_re;
       }
     }
 
+    // We should perform >>nr_csi_rs_info->log2_re here for all terms, but since sum2_re and sum2_im can be high values,
+    // we performed this above.
+    for(int p = 0; p<4; p++) {
+      int32_t power_re = sum2_re[p] - (sum_re[p]>>nr_csi_rs_info->log2_re)*(sum_re[p]>>nr_csi_rs_info->log2_re);
+      int32_t power_im = sum2_im[p] - (sum_im[p]>>nr_csi_rs_info->log2_re)*(sum_im[p]>>nr_csi_rs_info->log2_re);
+      tested_precoded_sinr[p] = (power_re+power_im)/(int32_t)noise_power;
+    }
+
     if(rank_indicator == 0) {
       for(int tested_i2 = 0; tested_i2 < 4; tested_i2++) {
-        if(metric[tested_i2] > metric[i2[0]]) {
+        if(tested_precoded_sinr[tested_i2] > tested_precoded_sinr[i2[0]]) {
           i2[0] = tested_i2;
         }
       }
+      *precoded_sinr = tested_precoded_sinr[i2[0]];
     } else {
-      i2[0] = metric[0]+metric[2] > metric[1]+metric[3] ? 0 : 1;
+      i2[0] = tested_precoded_sinr[0]+tested_precoded_sinr[2] > tested_precoded_sinr[1]+tested_precoded_sinr[3] ? 0 : 1;
+      *precoded_sinr = (tested_precoded_sinr[i2[0]] + tested_precoded_sinr[i2[0]+2])>>1;
     }
 
   } else {
@@ -598,9 +628,11 @@ int nr_ue_csi_rs_procedures(PHY_VARS_NR_UE *ue, UE_nr_rxtx_proc_t *proc, uint8_t
                            csirs_config_pdu,
                            ue->nr_csi_rs_info,
                            ue->nr_csi_rs_info->csi_rs_estimated_channel_freq,
+                           *ue->nr_csi_rs_info->noise_power,
                            *ue->nr_csi_rs_info->rank_indicator,
                            ue->nr_csi_rs_info->i1,
-                           ue->nr_csi_rs_info->i2);
+                           ue->nr_csi_rs_info->i2,
+                           ue->nr_csi_rs_info->precoded_sinr);
 
   LOG_I(NR_PHY, "RI = %i, i1 = %i.%i.%i, i2 = %i\n",
         *ue->nr_csi_rs_info->rank_indicator + 1,
diff --git a/openair1/PHY/defs_nr_common.h b/openair1/PHY/defs_nr_common.h
index c1157e452f3964efab7d977c767e9affc3f7f641..4616a87f1682caf6ee5cadd2aed710c5cd37c425 100644
--- a/openair1/PHY/defs_nr_common.h
+++ b/openair1/PHY/defs_nr_common.h
@@ -261,6 +261,7 @@ typedef struct {
   int32_t **csi_rs_received_signal;
   int32_t ***csi_rs_ls_estimated_channel;
   int32_t ***csi_rs_estimated_channel_freq;
+  int16_t log2_re;
   int16_t log2_maxh;
   int32_t csi_rs_estimated_conjch_ch[4][4][4][4][NR_MAX_OFDM_SYMBOL_SIZE] __attribute__((aligned(32)));
   int32_t csi_rs_estimated_A_MF[2][2][NR_MAX_OFDM_SYMBOL_SIZE] __attribute__((aligned(32)));
@@ -271,6 +272,7 @@ typedef struct {
   uint8_t *rank_indicator;
   uint8_t *i1;
   uint8_t *i2;
+  uint32_t *precoded_sinr;
 } nr_csi_rs_info_t;
 
 typedef struct NR_DL_FRAME_PARMS NR_DL_FRAME_PARMS;