c फास्ट AVX512 modulo जब एक ही भाजक



performance optimization (1)

मैंने संभावित फैक्टरियल प्राइम्स (फॉर्म एन - + - 1 की संख्या) के लिए डिवाइडर खोजने की कोशिश की है और क्योंकि मैंने हाल ही में स्काईलेक-एक्स वर्कस्टेशन खरीदा है मुझे लगता है कि मुझे एवीएक्स 512 निर्देशों का उपयोग करके कुछ गति मिल सकती है।

एल्गोरिदम सरल है और मुख्य कदम एक ही विभाजक के लिए बार-बार modulo लेना है। मुख्य बात एन मान की बड़ी रेंज पर लूप करना है। यहाँ भोली दृष्टिकोण सी में लिखा गया है (P primes की तालिका):

uint64_t factorial_naive(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i, residue;
for (i = 0; i < APP_BUFLEN; i++){
    residue = 2;
    for (n=3; n <= nmax; n++){
        residue *=  n;
        residue %= P[i];
        // Lets check if we found factor
        if (nmin <= n){
            if( residue == 1){
                report_factor(n, -1, P[i]);
            }
            if(residue == P[i]- 1){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

यहाँ विचार एन की एक बड़ी श्रृंखला की जाँच करने के लिए है, उदाहरण के लिए 1,000,000 -> 10,000,000 समान भाजक के खिलाफ। तो हम कई लाख बार एक ही विभाजक के लिए modulo सम्मान ले जाएगा। DIV का उपयोग करना बहुत धीमा है इसलिए गणना की सीमा के आधार पर कई संभावित दृष्टिकोण हैं। यहाँ मेरे मामले में n सबसे अधिक संभावना 10 ^ 7 से कम है और संभावित भाजक p 10,000 G (<10 ^ 13) से कम है, इसलिए संख्या 64-बिट से कम और 53-बिट से भी कम है!, लेकिन उत्पाद अधिकतम अवशेष (पी -1) बार n 64-बिट्स से बड़ा है। इसलिए मैंने सोचा कि मॉन्टगोमरी पद्धति का सबसे सरल संस्करण काम नहीं करता है क्योंकि हम संख्या से मोडुलो ले रहे हैं जो 64-बिट से बड़ा है।

मुझे पॉवर पीसी के लिए कुछ पुराने कोड मिले जहां FMA का उपयोग डबल्स का उपयोग करते समय 106 बिट्स (मुझे लगता है) तक एक सटीक उत्पाद प्राप्त करने के लिए किया गया था। इसलिए मैंने इस दृष्टिकोण को AVX 512 असेंबलर (Intel Intrinsics) में बदल दिया। यहाँ FMA विधि का एक सरल संस्करण है, यह Dekker (1971) के काम पर आधारित है, Dekker उत्पाद और TwoProduct का FMA संस्करण इसके पीछे उपयोगी शब्द हैं जब इसके पीछे तर्क खोजने / खोजने का प्रयास किया जाता है। इस दृष्टिकोण को भी इस मंच में चर्चा की गई है (जैसे here )।

int64_t factorial_FMA(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i;
double prime_double, prime_double_reciprocal, quotient, residue;
double nr, n_double, prime_times_quotient_high, prime_times_quotient_low;

for (i = 0; i < APP_BUFLEN; i++){
    residue = 2.0;
    prime_double = (double)P[i];
    prime_double_reciprocal = 1.0 / prime_double;
    n_double = 3.0;
    for (n=3; n <= nmax; n++){
        nr =  n_double * residue;
        quotient = fma(nr, prime_double_reciprocal, rounding_constant);
        quotient -= rounding_constant;
        prime_times_quotient_high= prime_double * quotient;
        prime_times_quotient_low = fma(prime_double, quotient, -prime_times_quotient_high);
        residue = fma(residue, n, -prime_times_quotient_high) - prime_times_quotient_low;

        if (residue < 0.0) residue += prime_double;
        n_double += 1.0;

        // Lets check if we found factor
        if (nmin <= n){
            if( residue == 1.0){
                report_factor(n, -1, P[i]);
            }
            if(residue == prime_double - 1.0){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

यहां मैंने जादू का उपयोग किया है

static const double rounding_constant = 6755399441055744.0; 

युगल के लिए 2 ^ 51 + 2 ^ 52 मैजिक नंबर है।

मैंने इसे AVX512 (32 संभावित विभाजक प्रति लूप) में बदल दिया और IACA का उपयोग करके परिणाम का विश्लेषण किया। इसने बताया कि विवादास्पद अड़चन: अनुपलब्ध आवंटन संसाधनों के कारण बैकएंड और बैकेंड आवंटन को रोक दिया गया था। मैं असेंबलर के साथ बहुत अनुभवी नहीं हूं, इसलिए मेरा सवाल यह है कि क्या मैं इस गति को बढ़ाने और इस बैकेंड की अड़चन को हल करने के लिए कुछ कर सकता हूं?

AVX512 कोड यहां है और इसे github से भी पाया जा सकता है

uint64_t factorial_AVX512_unrolled_four(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
// we are trying to find a factor for a factorial numbers : n! +-1
//nmin is minimum n we want to report and nmax is maximum. P is table of primes
// we process 32 primes in one loop.
// naive version of the algorithm is int he function factorial_naive
// and simple version of the FMA based approach in the function factorial_simpleFMA

const double one_table[8] __attribute__ ((aligned(64))) ={1.0, 1.0, 1.0,1.0,1.0,1.0,1.0,1.0};


uint64_t n;

__m512d zero, rounding_const, one, n_double;

__m512i prime1, prime2, prime3, prime4;

__m512d residue1, residue2, residue3, residue4;
__m512d prime_double_reciprocal1, prime_double_reciprocal2, prime_double_reciprocal3, prime_double_reciprocal4;
__m512d quotient1, quotient2, quotient3, quotient4;
__m512d prime_times_quotient_high1, prime_times_quotient_high2, prime_times_quotient_high3, prime_times_quotient_high4;
__m512d prime_times_quotient_low1, prime_times_quotient_low2, prime_times_quotient_low3, prime_times_quotient_low4;
__m512d nr1, nr2, nr3, nr4;
__m512d prime_double1, prime_double2, prime_double3, prime_double4;
__m512d prime_minus_one1, prime_minus_one2, prime_minus_one3, prime_minus_one4;

__mmask8 negative_reminder_mask1, negative_reminder_mask2, negative_reminder_mask3, negative_reminder_mask4;
__mmask8 found_factor_mask11, found_factor_mask12, found_factor_mask13, found_factor_mask14;
__mmask8 found_factor_mask21, found_factor_mask22, found_factor_mask23, found_factor_mask24;

// load data and initialize cariables for loop
rounding_const = _mm512_set1_pd(rounding_constant);
one = _mm512_load_pd(one_table);
zero = _mm512_setzero_pd ();

// load primes used to sieve
prime1 = _mm512_load_epi64((__m512i *) &P[0]);
prime2 = _mm512_load_epi64((__m512i *) &P[8]);
prime3 = _mm512_load_epi64((__m512i *) &P[16]);
prime4 = _mm512_load_epi64((__m512i *) &P[24]);

// convert primes to double
prime_double1 = _mm512_cvtepi64_pd (prime1); // vcvtqq2pd
prime_double2 = _mm512_cvtepi64_pd (prime2); // vcvtqq2pd
prime_double3 = _mm512_cvtepi64_pd (prime3); // vcvtqq2pd
prime_double4 = _mm512_cvtepi64_pd (prime4); // vcvtqq2pd

// calculates 1.0/ prime
prime_double_reciprocal1 = _mm512_div_pd(one, prime_double1);
prime_double_reciprocal2 = _mm512_div_pd(one, prime_double2);
prime_double_reciprocal3 = _mm512_div_pd(one, prime_double3);
prime_double_reciprocal4 = _mm512_div_pd(one, prime_double4);

// for comparison if we have found factors for n!+1
prime_minus_one1 = _mm512_sub_pd(prime_double1, one);
prime_minus_one2 = _mm512_sub_pd(prime_double2, one);
prime_minus_one3 = _mm512_sub_pd(prime_double3, one);
prime_minus_one4 = _mm512_sub_pd(prime_double4, one);

// residue init
residue1 =  _mm512_set1_pd(2.0);
residue2 =  _mm512_set1_pd(2.0);
residue3 =  _mm512_set1_pd(2.0);
residue4 =  _mm512_set1_pd(2.0);

// double counter init
n_double = _mm512_set1_pd(3.0);

// main loop starts here. typical value for nmax can be 5,000,000 -> 10,000,000

for (n=3; n<=nmax; n++) // main loop
{

    // timings for instructions:
    // _mm512_load_epi64 = vmovdqa64 : L 1, T 0.5
    // _mm512_load_pd = vmovapd : L 1, T 0.5
    // _mm512_set1_pd
    // _mm512_div_pd = vdivpd : L 23, T 16
    // _mm512_cvtepi64_pd = vcvtqq2pd : L 4, T 0,5

    // _mm512_mul_pd = vmulpd :  L 4, T 0.5
    // _mm512_fmadd_pd = vfmadd132pd, vfmadd213pd, vfmadd231pd :  L 4, T 0.5
    // _mm512_fmsub_pd = vfmsub132pd, vfmsub213pd, vfmsub231pd : L 4, T 0.5
    // _mm512_sub_pd = vsubpd : L 4, T 0.5
    // _mm512_cmplt_pd_mask = vcmppd : L ?, Y 1
    // _mm512_mask_add_pd = vaddpd :  L 4, T 0.5
    // _mm512_cmpeq_pd_mask = vcmppd L ?, Y 1
    // _mm512_kor = korw L 1, T 1

    // nr = residue *  n
    nr1 = _mm512_mul_pd (residue1, n_double);
    nr2 = _mm512_mul_pd (residue2, n_double);
    nr3 = _mm512_mul_pd (residue3, n_double);
    nr4 = _mm512_mul_pd (residue4, n_double);

    // quotient = nr * 1.0/ prime_double + rounding_constant
    quotient1 = _mm512_fmadd_pd(nr1, prime_double_reciprocal1, rounding_const);
    quotient2 = _mm512_fmadd_pd(nr2, prime_double_reciprocal2, rounding_const);
    quotient3 = _mm512_fmadd_pd(nr3, prime_double_reciprocal3, rounding_const);
    quotient4 = _mm512_fmadd_pd(nr4, prime_double_reciprocal4, rounding_const);

    // quotient -= rounding_constant, now quotient is rounded to integer
    // countient should be at maximum nmax (10,000,000)
    quotient1 = _mm512_sub_pd(quotient1, rounding_const);
    quotient2 = _mm512_sub_pd(quotient2, rounding_const);
    quotient3 = _mm512_sub_pd(quotient3, rounding_const);
    quotient4 = _mm512_sub_pd(quotient4, rounding_const);

    // now we calculate high and low for prime * quotient using decker product (FMA).
    // quotient is calculated using approximation but this is accurate for given quotient
    prime_times_quotient_high1 = _mm512_mul_pd(quotient1, prime_double1);
    prime_times_quotient_high2 = _mm512_mul_pd(quotient2, prime_double2);
    prime_times_quotient_high3 = _mm512_mul_pd(quotient3, prime_double3);
    prime_times_quotient_high4 = _mm512_mul_pd(quotient4, prime_double4);


    prime_times_quotient_low1 = _mm512_fmsub_pd(quotient1, prime_double1, prime_times_quotient_high1);
    prime_times_quotient_low2 = _mm512_fmsub_pd(quotient2, prime_double2, prime_times_quotient_high2);
    prime_times_quotient_low3 = _mm512_fmsub_pd(quotient3, prime_double3, prime_times_quotient_high3);
    prime_times_quotient_low4 = _mm512_fmsub_pd(quotient4, prime_double4, prime_times_quotient_high4);

    // now we calculate new reminder using decker product and using original values
    // we subtract above calculated prime * quotient (quotient is aproximation)

    residue1 = _mm512_fmsub_pd(residue1, n_double, prime_times_quotient_high1);
    residue2 = _mm512_fmsub_pd(residue2, n_double, prime_times_quotient_high2);
    residue3 = _mm512_fmsub_pd(residue3, n_double, prime_times_quotient_high3);
    residue4 = _mm512_fmsub_pd(residue4, n_double, prime_times_quotient_high4);

    residue1 = _mm512_sub_pd(residue1, prime_times_quotient_low1);
    residue2 = _mm512_sub_pd(residue2, prime_times_quotient_low2);
    residue3 = _mm512_sub_pd(residue3, prime_times_quotient_low3);
    residue4 = _mm512_sub_pd(residue4, prime_times_quotient_low4);

    // lets check if reminder < 0
    negative_reminder_mask1 = _mm512_cmplt_pd_mask(residue1,zero);
    negative_reminder_mask2 = _mm512_cmplt_pd_mask(residue2,zero);
    negative_reminder_mask3 = _mm512_cmplt_pd_mask(residue3,zero);
    negative_reminder_mask4 = _mm512_cmplt_pd_mask(residue4,zero);

    // we and prime back to reminder using mask if it was < 0
    residue1 = _mm512_mask_add_pd(residue1, negative_reminder_mask1, residue1, prime_double1);
    residue2 = _mm512_mask_add_pd(residue2, negative_reminder_mask2, residue2, prime_double2);
    residue3 = _mm512_mask_add_pd(residue3, negative_reminder_mask3, residue3, prime_double3);
    residue4 = _mm512_mask_add_pd(residue4, negative_reminder_mask4, residue4, prime_double4);

    n_double = _mm512_add_pd(n_double,one);

    // if we are below nmin then we continue next iteration
    if (n < nmin) continue;

    // Lets check if we found any factors, residue 1 == n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue prime -1  == n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
    { // we find factor very rarely

        double *residual_list1 = (double *) &residue1;
        double *residual_list2 = (double *) &residue2;
        double *residual_list3 = (double *) &residue3;
        double *residual_list4 = (double *) &residue4;

        double *prime_list1 = (double *) &prime_double1;
        double *prime_list2 = (double *) &prime_double2;
        double *prime_list3 = (double *) &prime_double3;
        double *prime_list4 = (double *) &prime_double4;



        for (int i=0; i <8; i++){
            if( residual_list1[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list1[i]);
            }
            if( residual_list2[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list2[i]);
            }
            if( residual_list3[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list3[i]);
            }
            if( residual_list4[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list4[i]);
            }

            if(residual_list1[i] == (prime_list1[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list1[i]);
            }
            if(residual_list2[i] == (prime_list2[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list2[i]);
            }
            if(residual_list3[i] == (prime_list3[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list3[i]);
            }
            if(residual_list4[i] == (prime_list4[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list4[i]);
            }
        }
    }

}

return EXIT_SUCCESS;
}

जैसा कि कुछ टिप्पणीकारों ने सुझाव दिया है: एक "बैकएंड" अड़चन वह है जो आप इस कोड के लिए अपेक्षा करेंगे। यह सुझाव देता है कि आप चीजों को बहुत अच्छी तरह से खिला रहे हैं, जो आप चाहते हैं।

रिपोर्ट को देखते हुए, इस खंड में एक अवसर होना चाहिए:

    // Lets check if we found any factors, residue 1 == n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue prime -1  == n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)

IACA विश्लेषण से:

|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r11d, k0
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw eax, k1
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw ecx, k2
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw esi, k3
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw edi, k4
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r8d, k5
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r9d, k6
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r10d, k7
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, eax
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, ecx
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, esi
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, edi
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, r8d
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, r9d
|   1*     |             |      |             |             |      |      |      |      | or r11d, r10d

प्रोसेसर परिणामी तुलना मास्क (k0-k7) को "या" ऑपरेशन के लिए नियमित रजिस्टरों पर ले जा रहा है। आपको उन चालों को खत्म करने में सक्षम होना चाहिए, और, 6ops बनाम 8 में "या" रोलअप करना चाहिए।

नोट: Found_factor_mask प्रकार को __mmask8 रूप में परिभाषित किया __mmask8 , जहां उन्हें __mask16 (एक 512bit क्षेत्र में 16x डबल फ़्लोट) होना चाहिए। यह कंपाइलर को कुछ अनुकूलन में मिल सकता है। यदि नहीं, तो एक टिप्पणीकार के रूप में विधानसभा को छोड़ दें।

और संबंधित: iteractions के किस अंश में यह या मास्क क्लॉज आग लगाता है? जैसा कि एक अन्य टिप्पणीकार ने देखा है, आपको इसे एक संकलित "या" ऑपरेशन के साथ नियंत्रित करने में सक्षम होना चाहिए। प्रत्येक अनियंत्रित चलना (या एन पुनरावृत्तियों के बाद) के अंत में संचित "या" मान की जांच करें, और यदि यह "सत्य" है, तो वापस जाएं और मानों का पता लगाने के लिए पुन: करें कि किस मान ने इसे ट्रिगर किया।

(और, आप मिलान एन मान को खोजने के लिए "रोल" के भीतर बाइनरी खोज कर सकते हैं - जिससे कुछ लाभ मिल सकता है)।

अगला, आपको इस मिड-लूप चेक से छुटकारा पाने में सक्षम होना चाहिए:

    // if we are below nmin then we continue next iteration, we
    if (n < nmin) continue;

जो यहाँ दिखाता है:

|   1*     |             |      |             |             |      |      |      |      | cmp r14, 0x3e8
|   0*F    |             |      |             |             |      |      |      |      | jb 0x229

हो सकता है कि भविष्यवक्ता (संभवतः) को यह एक (अधिकांशतः) सही मिले, लेकिन यह बहुत बड़ा लाभ नहीं हो सकता, लेकिन आपको दो "चरणों" के लिए दो अलग-अलग छोरों के होने से कुछ लाभ प्राप्त करने चाहिए

  • n = 3 से n = nmin-1
  • n = nmin और परे

यहां तक ​​कि अगर आप एक चक्र हासिल करते हैं, तो यह 3% है। और चूंकि यह आमतौर पर बड़े 'या' ऑपरेशन से संबंधित है, इसलिए, वहाँ पाया जा करने के लिए और अधिक होशियारी हो सकती है।





avx512