algorithm - xmm寄存器 - 查找SSE寄存器中最常出现的元素




sse编程 (2)

有没有人有任何想法如何计算在SSE4.x 8位整数向量的模式(统计)? 为了澄清,这将是128位寄存器中的16x8位值。

我想要结果作为选择模式值元素的矢量蒙版。 即_mm_cmpeq_epi8(v, set1(mode(v))) ,以及标量值。

提供一些额外的背景; 虽然上面的问题本身就是一个有趣的解决方案,但是我经历了大多数线性复杂的算法。 这个班级将从计算这个数字中获得任何收益。

我希望能在这里和大家一起寻找一些深奥的魔法。 有可能需要一个近似来打破这个界限,例如“选择一个经常出现的元素”(例如NB与最多的差异),这是有价值的。 一个概率的答案也是可用的。

SSE和x86有一些非常有趣的语义。 这可能是值得探索一个超级优化的通行证。


这里可能是一个比较简单的强力SSEx方法,请参阅下面的代码。 这个想法是将输入矢量v按字节旋转1到15个位置,并将旋转后的矢量与原始v进行比较以求相等。 为了缩短依赖链并提高指令级并行性,使用两个计数器来计数(垂直和)这些相等的元素: sum1sum2 ,因为可能有架构从中受益。 相等的元素被计为-1。 可变sum = sum1 + sum2包含总数,值在-1和-16之间。 min_brc包含广播到所有元素的水平最小值。 mask = _mm_cmpeq_epi8(sum,min_brc)是作为OP的中间结果请求的模式值元素的掩码。 在接下来的几行代码中,实际模式被提取。

这个解决方案肯定比标量解决方案更快。 请注意,使用AVX2时,可以使用高128位的通道来进一步加速计算。

只需要20个周期(吞吐量)来计算模式值元素的掩码。 在SSE寄存器中广播的实际模式大约需要21.4个周期。

请注意下一个示例中的行为: [1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]返回mask=[-1,-1,-1,-1,0,0,...,0] ,模式值为1,尽管1经常出现为3。

下面的代码已经过测试,但没有经过彻底的测试

#include <stdio.h>
#include <x86intrin.h>
/*  gcc -O3 -Wall -m64 -march=nehalem mode_uint8.c   */
int print_vec_char(__m128i x);

__m128i mode_statistic(__m128i v){
    __m128i  sum2         = _mm_set1_epi8(-1);                    /* Each integer occurs at least one time */
    __m128i  v_rot1       = _mm_alignr_epi8(v,v,1);
    __m128i  v_rot2       = _mm_alignr_epi8(v,v,2);
    __m128i  sum1         =                   _mm_cmpeq_epi8(v,v_rot1);
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot2));

    __m128i  v_rot3       = _mm_alignr_epi8(v,v,3);
    __m128i  v_rot4       = _mm_alignr_epi8(v,v,4);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot3));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot4));

    __m128i  v_rot5       = _mm_alignr_epi8(v,v,5);
    __m128i  v_rot6       = _mm_alignr_epi8(v,v,6);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot5));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot6));

    __m128i  v_rot7       = _mm_alignr_epi8(v,v,7);
    __m128i  v_rot8       = _mm_alignr_epi8(v,v,8);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot7));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot8));

    __m128i  v_rot9       = _mm_alignr_epi8(v,v,9);
    __m128i  v_rot10      = _mm_alignr_epi8(v,v,10);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot9));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot10));

    __m128i  v_rot11      = _mm_alignr_epi8(v,v,11);
    __m128i  v_rot12      = _mm_alignr_epi8(v,v,12);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot11));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot12));

    __m128i  v_rot13      = _mm_alignr_epi8(v,v,13);
    __m128i  v_rot14      = _mm_alignr_epi8(v,v,14);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot13));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot14));

    __m128i  v_rot15      = _mm_alignr_epi8(v,v,15);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot15));
    __m128i  sum          = _mm_add_epi8(sum1,sum2);                      /* Sum contains values such as -1, -2 ,...,-16                                    */
                                                                          /* The next three instructions compute the horizontal minimum of sum */
    __m128i  sum_shft     = _mm_srli_epi16(sum,8);                        /* Shift right 8 bits, while shifting in zeros                                    */
    __m128i  min1         = _mm_min_epu8(sum,sum_shft);                   /* sum and sum_shuft are considered as unsigned integers. sum_shft is zero at the odd positions and so is min1 */ 
    __m128i  min2         = _mm_minpos_epu16(min1);                       /* Byte 0 within min2 contains the horizontal minimum of sum                      */
    __m128i  min_brc      = _mm_shuffle_epi8(min2,_mm_setzero_si128());   /* Broadcast horizontal minimum                                                   */

    __m128i  mask         = _mm_cmpeq_epi8(sum,min_brc);                  /* Mask = -1 at the byte positions where the value of v is equal to the mode of v */

    /* comment next 4 lines out if there is no need to broadcast the mode value */
    int      bitmask      = _mm_movemask_epi8(mask);
    int      indx         = __builtin_ctz(bitmask);                            /* Index of mode                            */
    __m128i  v_indx       = _mm_set1_epi8(indx);                               /* Broadcast indx                           */
    __m128i  answer       = _mm_shuffle_epi8(v,v_indx);                        /* Broadcast mode to each element of answer */ 

/* Uncomment lines below to print intermediate results, to see how it works. */
//    printf("sum         = ");print_vec_char (sum           );
//    printf("sum_shft    = ");print_vec_char (sum_shft      );
//    printf("min1        = ");print_vec_char (min1          );
//    printf("min2        = ");print_vec_char (min2          );
//    printf("min_brc     = ");print_vec_char (min_brc       );
//    printf("mask        = ");print_vec_char (mask          );
//    printf("v_indx      = ");print_vec_char (v_indx        );
//    printf("answer      = ");print_vec_char (answer        );

             return answer;   /* or return mask, or return both ....    :) */
}


int main() {
    /* To test throughput set throughput_test to 1, otherwise 0    */
    /* Use e.g. perf stat -d ./a.out to test throughput           */
    #define throughput_test 0

    /* Different test vectors  */
    int i;
    char   x1[16] = {5, 2, 2, 7, 21, 4, 7, 7, 3, 9, 2, 5, 4, 3, 5, 5};
    char   x2[16] = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5};
    char   x3[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    char   x4[16] = {1, 2, 3, 2, 1, 6, 7, 8, 2, 2, 2, 3, 3, 2, 15, 16};
    char   x5[16] = {1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

    printf("\n15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0\n\n");

    __m128i  x_vec  = _mm_loadu_si128((__m128i*)x1);

    printf("x_vec       = ");print_vec_char(x_vec        );

    __m128i  y      = mode_statistic (x_vec);

    printf("answer      = ");print_vec_char(y         );


    #if throughput_test == 1
    __m128i  x_vec1  = _mm_loadu_si128((__m128i*)x1);
    __m128i  x_vec2  = _mm_loadu_si128((__m128i*)x2);
    __m128i  x_vec3  = _mm_loadu_si128((__m128i*)x3);
    __m128i  x_vec4  = _mm_loadu_si128((__m128i*)x4);
    __m128i  x_vec5  = _mm_loadu_si128((__m128i*)x5);
    __m128i  y1, y2, y3, y4, y5;
    __asm__ __volatile__ ( "vzeroupper" : : : );   /* Remove this line on non-AVX processors */
    for (i=0;i<100000000;i++){
        y1       = mode_statistic (x_vec1);
        y2       = mode_statistic (x_vec2);
        y3       = mode_statistic (x_vec3);
        y4       = mode_statistic (x_vec4);
        y5       = mode_statistic (x_vec5);
        x_vec1   = mode_statistic (y1    );
        x_vec2   = mode_statistic (y2    );
        x_vec3   = mode_statistic (y3    );
        x_vec4   = mode_statistic (y4    );
        x_vec5   = mode_statistic (y5    );
     }
    printf("mask mode   = ");print_vec_char(y1           );
    printf("mask mode   = ");print_vec_char(y2           );
    printf("mask mode   = ");print_vec_char(y3           );
    printf("mask mode   = ");print_vec_char(y4           );
    printf("mask mode   = ");print_vec_char(y5           );
    #endif

    return 0;
}



int print_vec_char(__m128i x){
    char v[16];
    _mm_storeu_si128((__m128i *)v,x);
    printf("%3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi\n",
           v[15],v[14],v[13],v[12],v[11],v[10],v[9],v[8],v[7],v[6],v[5],v[4],v[3],v[2],v[1],v[0]);
    return 0;
}

输出:

15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0

x_vec       =   5   5   3   4 |   5   2   9   3 |   7   7   4  21 |   7   2   2   5
sum         =  -4  -4  -2  -2 |  -4  -3  -1  -2 |  -3  -3  -2  -1 |  -3  -3  -3  -4
min_brc     =  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4
mask        =  -1  -1   0   0 |  -1   0   0   0 |   0   0   0   0 |   0   0   0  -1
answer      =   5   5   5   5 |   5   5   5   5 |   5   5   5   5 |   5   5   5   5

水平最小值用Evgeny Kluev的方法计算 。


对寄存器中的数据进行排序。 插入排序可以在16(15)个步骤中完成,通过将寄存器初始化为“Infinity”,试图说明单调递减的数组并将新元素并行插入到所有可能的位置:

// e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78
__m128i sorted = _mm_or_si128(my_array, const_FFFFF00);

for (int i = 1; i < 16; ++i)
{
    // Trying to insert e.g. A0, we must shift all the FF's to left
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78 00
    __m128i shifted = _mm_bslli_si128(sorted, 1);

    // Taking the MAX of shifted and 'A0 on all places'
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 A0
    shifted = _mm_max_epu8(shifted, _mm_set1_epi8(my_array[i]));

    // and minimum of the shifted + original --
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 78
    sorted = _mm_min_epu8(sorted, shifted);
}

然后计算vec[n+1] == vec[n]掩码,将掩码移到GPR并使用它来索引32768个条目LUT以获得最佳索引位置。

在实际情况下,人们可能想要排序的不仅仅是一个向量; 即一次对16个16入口向量进行排序;

__m128i input[16];      // not 1, but 16 vectors
transpose16x16(input);  // inplace vector transpose
sort(transpose);        // 60-stage network exists for 16 inputs
// linear search -- result in 'mode'
__m128i mode = input[0];
__m128i previous = mode;
__m128i count = _mm_set_epi8(0);
__m128i max_count = _mm_setzero_si128(0);
for (int i = 1; i < 16; i++)
{
   __m128i &current = input[i];
   // histogram count is off by one
   // if (current == previous) count++;
   //    else count = 0;
   // if (count > max_count)
   //    mode = current, max_count = count
   prev = _mm_cmpeq_epi8(prev, current);
   count = _mm_and_si128(_mm_sub_epi8(count, prev), prev);
   __m128i max_so_far = _mm_cmplt_epi8(max_count, count);
   mode = _mm_blendv_epi8(mode, current, max_so_far);
   max_count = _mm_max_epi8(max_count, count);
   previous = current;
}

内部循环总计每个结果7-8条指令的分摊成本; 分类每阶段通常有2条指令 - 即每个结果8条指令,当16条结果需要60个阶段或120条指令时。 (这仍然是一个运动的转置 - 但我认为它应该比排序快得多?)

所以,这应该是在每8位结果24个指令的球场里。





sse