algorithm sémantique Recherche de l'élément le plus fréquent dans un registre SSE




lexico-sémantique pdf (3)

Trier les données dans le registre. Le tri par insertion peut être fait en 16 (15) étapes, en initialisant le registre à "Infinity", qui essaie d'illustrer un tableau décroissant de façon monotone et en insérant le nouvel élément en parallèle à tous les endroits possibles:

// e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78
__m128i sorted = _mm_or_si128(my_array, const_FFFFF00);

for (int i = 1; i < 16; ++i)
{
    // Trying to insert e.g. A0, we must shift all the FF's to left
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF 78 00
    __m128i shifted = _mm_bslli_si128(sorted, 1);

    // Taking the MAX of shifted and 'A0 on all places'
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 A0
    shifted = _mm_max_epu8(shifted, _mm_set1_epi8(my_array[i]));

    // and minimum of the shifted + original --
    // e.g. FF FF FF FF FF FF FF FF FF FF FF FF FF FF A0 78
    sorted = _mm_min_epu8(sorted, shifted);
}

Puis calculez le masque pour vec[n+1] == vec[n] , déplacez le masque vers GPR et utilisez-le pour indexer une LUT d'entrée 32768 pour le meilleur emplacement d'index.

Dans le cas réel, on veut probablement trier plus qu'un seul vecteur; à savoir trier 16 vecteurs à 16 entrées à la fois;

__m128i input[16];      // not 1, but 16 vectors
transpose16x16(input);  // inplace vector transpose
sort(transpose);        // 60-stage network exists for 16 inputs
// linear search -- result in 'mode'
__m128i mode = input[0];
__m128i previous = mode;
__m128i count = _mm_set_epi8(0);
__m128i max_count = _mm_setzero_si128(0);
for (int i = 1; i < 16; i++)
{
   __m128i &current = input[i];
   // histogram count is off by one
   // if (current == previous) count++;
   //    else count = 0;
   // if (count > max_count)
   //    mode = current, max_count = count
   prev = _mm_cmpeq_epi8(prev, current);
   count = _mm_and_si128(_mm_sub_epi8(count, prev), prev);
   __m128i max_so_far = _mm_cmplt_epi8(max_count, count);
   mode = _mm_blendv_epi8(mode, current, max_so_far);
   max_count = _mm_max_epi8(max_count, count);
   previous = current;
}

La boucle interne totalise le coût amorti de 7-8 instructions par résultat; Le tri comporte généralement 2 instructions par étape, soit 8 instructions par résultat, alors que 16 résultats nécessitent 60 étapes ou 120 instructions. (Cela laisse toujours la transposition comme un exercice - mais je pense que cela devrait être beaucoup plus rapide que le tri?)

Donc, cela devrait être dans le parc de balles de 24 instructions par résultat de 8 bits.

Est-ce que quelqu'un a des idées sur la façon de calculer le mode (statistique) d'un vecteur d'entiers de 8 bits dans SSE4.x? Pour clarifier, ce serait des valeurs de 16x8 bits dans un registre de 128 bits.

Je veux que le résultat soit un masque vectoriel qui sélectionne les éléments à valeur de mode. c'est-à-dire le résultat de _mm_cmpeq_epi8(v, set1(mode(v))) , ainsi que la valeur scalaire.

Fournir un contexte supplémentaire Bien que le problème ci-dessus soit intéressant à résoudre en soi, j'ai été confronté à la plupart des algorithmes auxquels je peux penser avec une complexité linéaire. Cette classe effacera tous les gains que je peux obtenir en calculant ce nombre.

J'espère vous engager tous dans la recherche d'une magie profonde, ici. Il est possible qu'une approximation soit nécessaire pour rompre cette limite, comme par exemple "sélectionner un élément fréquent" (différence NB contre plus ), ce qui serait un mérite. Une réponse probabiliste serait également utilisable.

SSE et x86 ont une sémantique très intéressante. Il peut être intéressant d'explorer une passe de superoptimisation.


Probablement une approche SSEx force brute relativement simple est appropriée ici, voir le code ci-dessous. L'idée est de faire une rotation octet du vecteur d'entrée v de 1 à 15 positions et de comparer le vecteur pivoté avec l'original v pour l'égalité. Pour raccourcir la chaîne de dépendances et augmenter le parallélisme du niveau d'instruction, deux compteurs sont utilisés pour compter (somme verticale) ces éléments égaux: sum1 et sum2 , car il pourrait y avoir des architectures qui en bénéficient. Les éléments égaux sont comptés comme -1. La sum = sum1 + sum2 variables sum = sum1 + sum2 contient le nombre total avec des valeurs comprises entre -1 et -16. min_brc contient le minimum horizontal de la sum diffusée à tous les éléments. mask = _mm_cmpeq_epi8(sum,min_brc) est le masque pour les éléments à valeur de mode demandé comme résultat intermédiaire par l'OP. Dans les quelques lignes suivantes du code, le mode réel est extrait.

Cette solution est certainement plus rapide qu'une solution scalaire. Notez qu'avec AVX2 les voies 128 bits supérieures peuvent être utilisées pour accélérer le calcul plus loin.

Il faut 20 cycles (débit) pour calculer seulement un masque pour les éléments à valeur de mode. Avec le mode réel diffusé sur le registre SSE, il faut environ 21,4 cycles.

Notez le comportement dans l'exemple suivant: [1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] renvoie mask=[-1,-1,-1,-1,0,0,...,0] et la valeur du mode est 1, bien que 1 apparaisse aussi souvent que 3.

Le code ci-dessous est testé, mais pas complètement testé

#include <stdio.h>
#include <x86intrin.h>
/*  gcc -O3 -Wall -m64 -march=nehalem mode_uint8.c   */
int print_vec_char(__m128i x);

__m128i mode_statistic(__m128i v){
    __m128i  sum2         = _mm_set1_epi8(-1);                    /* Each integer occurs at least one time */
    __m128i  v_rot1       = _mm_alignr_epi8(v,v,1);
    __m128i  v_rot2       = _mm_alignr_epi8(v,v,2);
    __m128i  sum1         =                   _mm_cmpeq_epi8(v,v_rot1);
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot2));

    __m128i  v_rot3       = _mm_alignr_epi8(v,v,3);
    __m128i  v_rot4       = _mm_alignr_epi8(v,v,4);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot3));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot4));

    __m128i  v_rot5       = _mm_alignr_epi8(v,v,5);
    __m128i  v_rot6       = _mm_alignr_epi8(v,v,6);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot5));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot6));

    __m128i  v_rot7       = _mm_alignr_epi8(v,v,7);
    __m128i  v_rot8       = _mm_alignr_epi8(v,v,8);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot7));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot8));

    __m128i  v_rot9       = _mm_alignr_epi8(v,v,9);
    __m128i  v_rot10      = _mm_alignr_epi8(v,v,10);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot9));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot10));

    __m128i  v_rot11      = _mm_alignr_epi8(v,v,11);
    __m128i  v_rot12      = _mm_alignr_epi8(v,v,12);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot11));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot12));

    __m128i  v_rot13      = _mm_alignr_epi8(v,v,13);
    __m128i  v_rot14      = _mm_alignr_epi8(v,v,14);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot13));
             sum2         = _mm_add_epi8(sum2,_mm_cmpeq_epi8(v,v_rot14));

    __m128i  v_rot15      = _mm_alignr_epi8(v,v,15);
             sum1         = _mm_add_epi8(sum1,_mm_cmpeq_epi8(v,v_rot15));
    __m128i  sum          = _mm_add_epi8(sum1,sum2);                      /* Sum contains values such as -1, -2 ,...,-16                                    */
                                                                          /* The next three instructions compute the horizontal minimum of sum */
    __m128i  sum_shft     = _mm_srli_epi16(sum,8);                        /* Shift right 8 bits, while shifting in zeros                                    */
    __m128i  min1         = _mm_min_epu8(sum,sum_shft);                   /* sum and sum_shuft are considered as unsigned integers. sum_shft is zero at the odd positions and so is min1 */ 
    __m128i  min2         = _mm_minpos_epu16(min1);                       /* Byte 0 within min2 contains the horizontal minimum of sum                      */
    __m128i  min_brc      = _mm_shuffle_epi8(min2,_mm_setzero_si128());   /* Broadcast horizontal minimum                                                   */

    __m128i  mask         = _mm_cmpeq_epi8(sum,min_brc);                  /* Mask = -1 at the byte positions where the value of v is equal to the mode of v */

    /* comment next 4 lines out if there is no need to broadcast the mode value */
    int      bitmask      = _mm_movemask_epi8(mask);
    int      indx         = __builtin_ctz(bitmask);                            /* Index of mode                            */
    __m128i  v_indx       = _mm_set1_epi8(indx);                               /* Broadcast indx                           */
    __m128i  answer       = _mm_shuffle_epi8(v,v_indx);                        /* Broadcast mode to each element of answer */ 

/* Uncomment lines below to print intermediate results, to see how it works. */
//    printf("sum         = ");print_vec_char (sum           );
//    printf("sum_shft    = ");print_vec_char (sum_shft      );
//    printf("min1        = ");print_vec_char (min1          );
//    printf("min2        = ");print_vec_char (min2          );
//    printf("min_brc     = ");print_vec_char (min_brc       );
//    printf("mask        = ");print_vec_char (mask          );
//    printf("v_indx      = ");print_vec_char (v_indx        );
//    printf("answer      = ");print_vec_char (answer        );

             return answer;   /* or return mask, or return both ....    :) */
}


int main() {
    /* To test throughput set throughput_test to 1, otherwise 0    */
    /* Use e.g. perf stat -d ./a.out to test throughput           */
    #define throughput_test 0

    /* Different test vectors  */
    int i;
    char   x1[16] = {5, 2, 2, 7, 21, 4, 7, 7, 3, 9, 2, 5, 4, 3, 5, 5};
    char   x2[16] = {5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5};
    char   x3[16] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    char   x4[16] = {1, 2, 3, 2, 1, 6, 7, 8, 2, 2, 2, 3, 3, 2, 15, 16};
    char   x5[16] = {1, 1, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};

    printf("\n15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0\n\n");

    __m128i  x_vec  = _mm_loadu_si128((__m128i*)x1);

    printf("x_vec       = ");print_vec_char(x_vec        );

    __m128i  y      = mode_statistic (x_vec);

    printf("answer      = ");print_vec_char(y         );


    #if throughput_test == 1
    __m128i  x_vec1  = _mm_loadu_si128((__m128i*)x1);
    __m128i  x_vec2  = _mm_loadu_si128((__m128i*)x2);
    __m128i  x_vec3  = _mm_loadu_si128((__m128i*)x3);
    __m128i  x_vec4  = _mm_loadu_si128((__m128i*)x4);
    __m128i  x_vec5  = _mm_loadu_si128((__m128i*)x5);
    __m128i  y1, y2, y3, y4, y5;
    __asm__ __volatile__ ( "vzeroupper" : : : );   /* Remove this line on non-AVX processors */
    for (i=0;i<100000000;i++){
        y1       = mode_statistic (x_vec1);
        y2       = mode_statistic (x_vec2);
        y3       = mode_statistic (x_vec3);
        y4       = mode_statistic (x_vec4);
        y5       = mode_statistic (x_vec5);
        x_vec1   = mode_statistic (y1    );
        x_vec2   = mode_statistic (y2    );
        x_vec3   = mode_statistic (y3    );
        x_vec4   = mode_statistic (y4    );
        x_vec5   = mode_statistic (y5    );
     }
    printf("mask mode   = ");print_vec_char(y1           );
    printf("mask mode   = ");print_vec_char(y2           );
    printf("mask mode   = ");print_vec_char(y3           );
    printf("mask mode   = ");print_vec_char(y4           );
    printf("mask mode   = ");print_vec_char(y5           );
    #endif

    return 0;
}



int print_vec_char(__m128i x){
    char v[16];
    _mm_storeu_si128((__m128i *)v,x);
    printf("%3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi | %3hhi %3hhi %3hhi %3hhi\n",
           v[15],v[14],v[13],v[12],v[11],v[10],v[9],v[8],v[7],v[6],v[5],v[4],v[3],v[2],v[1],v[0]);
    return 0;
}

Sortie:

15...0      =   15  14  13  12    11  10  9   8     7   6   5   4     3   2   1   0

x_vec       =   5   5   3   4 |   5   2   9   3 |   7   7   4  21 |   7   2   2   5
sum         =  -4  -4  -2  -2 |  -4  -3  -1  -2 |  -3  -3  -2  -1 |  -3  -3  -3  -4
min_brc     =  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4 |  -4  -4  -4  -4
mask        =  -1  -1   0   0 |  -1   0   0   0 |   0   0   0   0 |   0   0   0  -1
answer      =   5   5   5   5 |   5   5   5   5 |   5   5   5   5 |   5   5   5   5

Le minimum horizontal est calculé avec la méthode d' Evgeny Kluev.


Pour la comparaison des performances avec le code scalaire. Non-vectorisé sur la partie principale mais vectorisé sur l'initialisation table-clear et tmp. (168 cycles par f () appel pour fx8150 (22M appels complets en 1.0002 secondes à 3,7 GHz))

#include <x86intrin.h>

unsigned char tmp[16]; // extracted values are here (single instruction, store_ps)
unsigned char table[256]; // counter table containing zeroes
char f(__m128i values)
{
    _mm_store_si128((__m128i *)tmp,values);
    int maxOccurence=0;
    int currentValue=0;
    for(int i=0;i<16;i++)
    {
        unsigned char ind=tmp[i];
        unsigned char t=table[ind];
        t++;
        if(t>maxOccurence)
        {
             maxOccurence=t;
             currentValue=ind;
        }
        table[ind]=t;
    }
    for(int i=0;i<256;i++)
        table[i]=0;
    return currentValue;
}

g ++ 6.3 sortie:

f:                                      # @f
        movaps  %xmm0, tmp(%rip)
        movaps  %xmm0, -24(%rsp)
        xorl    %r8d, %r8d
        movq    $-15, %rdx
        movb    -24(%rsp), %sil
        xorl    %eax, %eax
        jmp     .LBB0_1
.LBB0_2:                                # %._crit_edge
        cmpl    %r8d, %esi
        cmovgel %esi, %r8d
        movb    tmp+16(%rdx), %sil
        incq    %rdx
.LBB0_1:                                # =>This Inner Loop Header: Depth=1
        movzbl  %sil, %edi
        movb    table(%rdi), %cl
        incb    %cl
        movzbl  %cl, %esi
        cmpl    %r8d, %esi
        cmovgl  %edi, %eax
        movb    %sil, table(%rdi)
        testq   %rdx, %rdx
        jne     .LBB0_2
        xorps   %xmm0, %xmm0
        movaps  %xmm0, table+240(%rip)
        movaps  %xmm0, table+224(%rip)
        movaps  %xmm0, table+208(%rip)
        movaps  %xmm0, table+192(%rip)
        movaps  %xmm0, table+176(%rip)
        movaps  %xmm0, table+160(%rip)
        movaps  %xmm0, table+144(%rip)
        movaps  %xmm0, table+128(%rip)
        movaps  %xmm0, table+112(%rip)
        movaps  %xmm0, table+96(%rip)
        movaps  %xmm0, table+80(%rip)
        movaps  %xmm0, table+64(%rip)
        movaps  %xmm0, table+48(%rip)
        movaps  %xmm0, table+32(%rip)
        movaps  %xmm0, table+16(%rip)
        movaps  %xmm0, table(%rip)
        movsbl  %al, %eax
        ret




sse