performance optimization - 為什麼這個C++代碼比測試Collat z猜想的手寫組件更快?




optimized pdf (10)

我在彙編和C ++中為Project Euler Q14編寫了這兩個解決方案。 它們是測試Collat​​z猜想的完全相同的蠻力方法。 該組裝解決方案與組裝

nasm -felf64 p14.asm && gcc p14.o -o p14

C ++編譯時使用

g++ p14.cpp -o p14

大會, p14.asm

section .data
    fmt db "%d", 10, 0

global main
extern printf

section .text

main:
    mov rcx, 1000000
    xor rdi, rdi        ; max i
    xor rsi, rsi        ; i

l1:
    dec rcx
    xor r10, r10        ; count
    mov rax, rcx

l2:
    test rax, 1
    jpe even

    mov rbx, 3
    mul rbx
    inc rax
    jmp c1

even:
    mov rbx, 2
    xor rdx, rdx
    div rbx

c1:
    inc r10
    cmp rax, 1
    jne l2

    cmp rdi, r10
    cmovl rdi, r10
    cmovl rsi, rcx

    cmp rcx, 2
    jne l1

    mov rdi, fmt
    xor rax, rax
    call printf
    ret

C ++,p14.cpp

#include <iostream>

using namespace std;

int sequence(long n) {
    int count = 1;
    while (n != 1) {
        if (n % 2 == 0)
            n /= 2;
        else
            n = n*3 + 1;

        ++count;
    }

    return count;
}

int main() {
    int max = 0, maxi;
    for (int i = 999999; i > 0; --i) {
        int s = sequence(i);
        if (s > max) {
            max = s;
            maxi = i;
        }
    }

    cout << maxi << endl;
}

我知道編譯器優化以提高速度和一切,但我沒有看到很多方法來進一步優化我的彙編解決方案(以編程方式不以數學方式說)。

C ++代碼每個術語和每個階段都有模數,每個階段的彙編只有一個分區。

但是該組件平均比C ++解決方案平均長1秒。 為什麼是這樣? 主要是好奇心。

執行時間

我的系統:1.4 GHz Intel Celeron 2955U(Haswell微架構)上的64位Linux。

  • g++ (未優化):平均1272毫秒

  • g++ -O3 avg 578 ms

  • 原始asm(div)平均2650毫秒

  • Asm (shr) avg 679 ms

  • @johnfound asm ,與nasm avg 501 ms組裝

  • @hidefromkgb asm avg 200 ms

  • @hidefromkgb asm由@Peter Cordes avg 145 ms 優化

  • @Veedrac C ++平均值為81毫秒, -O3 305毫秒, -O0平均值


Answers

聲稱C ++編譯器可以產生比合格彙編語言程序員更優化的代碼是一個非常糟糕的錯誤。 特別是在這種情況下。 人們總是可以使代碼更好,編譯器可以,而這種特殊情況就是對這種說法的很好說明。

你看到的時間差異是因為問題中的彙編代碼在內部循環中非常不理想。

(以下代碼是32位,但可以輕鬆轉換為64位)

例如,序列函數可以優化為只有5條指令:

    .seq:
        inc     esi                 ; counter
        lea     edx, [3*eax+1]      ; edx = 3*n+1
        shr     eax, 1              ; eax = n/2
        cmovc   eax, edx            ; if CF eax = edx
        jnz     .seq                ; jmp if n<>1

整個代碼如下所示:

include "%lib%/freshlib.inc"
@BinaryType console, compact
options.DebugMode = 1
include "%lib%/freshlib.asm"

start:
        InitializeAll
        mov ecx, 999999
        xor edi, edi        ; max
        xor ebx, ebx        ; max i

    .main_loop:

        xor     esi, esi
        mov     eax, ecx

    .seq:
        inc     esi                 ; counter
        lea     edx, [3*eax+1]      ; edx = 3*n+1
        shr     eax, 1              ; eax = n/2
        cmovc   eax, edx            ; if CF eax = edx
        jnz     .seq                ; jmp if n<>1

        cmp     edi, esi
        cmovb   edi, esi
        cmovb   ebx, ecx

        dec     ecx
        jnz     .main_loop

        OutputValue "Max sequence: ", edi, 10, -1
        OutputValue "Max index: ", ebx, 10, -1

        FinalizeAll
        stdcall TerminateAll, 0

為了編譯這段代碼,需要FreshLib

在我的測試中(1 GHz AMD A4-1200處理器),上述代碼比問題中的C ++代碼快約四倍(用-O0編譯時:430毫秒比1900毫秒),速度提高了兩倍以上(430毫秒比830毫秒),當C ++代碼用-O3編譯時。

兩個程序的輸出相同:i = 837799上的最大序列= 525。


You did not post the code generated by the compiler, so there' some guesswork here, but even without having seen it, one can say that this:

test rax, 1
jpe even

... has a 50% chance of mispredicting the branch, and that will come expensive.

The compiler almost certainly does both computations (which costs neglegibly more since the div/mod is quite long latency, so the multiply-add is "free") and follows up with a CMOV. Which, of course, has a zero percent chance of being mispredicted.


C++ programs are translated to assembly programs during the generation of machine code from the source code. It would be virtually wrong to say assembly is slower than C++. Moreover, the binary code generated differs from compiler to compiler. So a smart C++ compiler may produce binary code more optimal and efficient than a dumb assembler's code.

However I believe your profiling methodology has certain flaws. The following are general guidelines for profiling:

  1. Make sure your system is in its normal/idle state. Stop all running processes (applications) that you started or that use CPU intensively (or poll over the network).
  2. Your datasize must be greater in size.
  3. Your test must run for something more than 5-10 seconds.
  4. Do not rely on just one sample. Perform your test N times. Collect results and calculate the mean or median of the result.

Even without looking at assembly, the most obvious reason is that /= 2 is probably optimized as >>=1 and many processors have a very quick shift operation. But even if a processor doesn't have a shift operation, the integer division is faster than floating point division.

Edit: your milage may vary on the "integer division is faster than floating point division" statement above. The comments below reveal that the modern processors have prioritized optimizing fp division over integer division. So if someone were looking for the most likely reason for the speedup which this thread's question asks about, then compiler optimizing /=2 as >>=1 would be the best 1st place to look.

On an unrelated note , if n is odd, the expression n*3+1 will always be even. So there is no need to check. You can change that branch to

{
   n = (n*3+1) >> 1;
   count += 2;
}

So the whole statement would then be

if (n & 1)
{
    n = (n*3 + 1) >> 1;
    count += 2;
}
else
{
    n >>= 1;
    ++count;
}

For the Collatz problem, you can get a significant boost in performance by caching the "tails". This is a time/memory trade-off. See: memoization ( https://en.wikipedia.org/wiki/Memoization ). You could also look into dynamic programming solutions for other time/memory trade-offs.

Example python implementation:

import sys

inner_loop = 0

def collatz_sequence(N, cache):
    global inner_loop

    l = [ ]
    stop = False
    n = N

    tails = [ ]

    while not stop:
        inner_loop += 1
        tmp = n
        l.append(n)
        if n <= 1:
            stop = True  
        elif n in cache:
            stop = True
        elif n % 2:
            n = 3*n + 1
        else:
            n = n // 2
        tails.append((tmp, len(l)))

    for key, offset in tails:
        if not key in cache:
            cache[key] = l[offset:]

    return l

def gen_sequence(l, cache):
    for elem in l:
        yield elem
        if elem in cache:
            yield from gen_sequence(cache[elem], cache)
            raise StopIteration

if __name__ == "__main__":
    le_cache = {}

    for n in range(1, 4711, 5):
        l = collatz_sequence(n, le_cache)
        print("{}: {}".format(n, len(list(gen_sequence(l, le_cache)))))

    print("inner_loop = {}".format(inner_loop))

From comments:

But, this code never stops (because of integer overflow) !?! Yves Daoust

For many numbers it will not overflow.

If it will overflow - for one of those unlucky initial seeds, the overflown number will very likely converge toward 1 without another overflow.

Still this poses interesting question, is there some overflow-cyclic seed number?

Any simple final converging series starts with power of two value (obvious enough?).

2^64 will overflow to zero, which is undefined infinite loop according to algorithm (ends only with 1), but the most optimal solution in answer will finish due to shr rax producing ZF=1.

Can we produce 2^64? If the starting number is 0x5555555555555555 , it's odd number, next number is then 3n+1, which is 0xFFFFFFFFFFFFFFFF + 1 = 0 . Theoretically in undefined state of algorithm, but the optimized answer of johnfound will recover by exiting on ZF=1. The cmp rax,1 of Peter Cordes will end in infinite loop (QED variant 1, "cheapo" through undefined 0 number).

How about some more complex number, which will create cycle without 0 ? Frankly, I'm not sure, my Math theory is too hazy to get any serious idea, how to deal with it in serious way. But intuitively I would say the series will converge to 1 for every number : 0 < number, as the 3n+1 formula will slowly turn every non-2 prime factor of original number (or intermediate) into some power of 2, sooner or later. So we don't need to worry about infinite loop for original series, only overflow can hamper us.

So I just put few numbers into sheet and took a look on 8 bit truncated numbers.

There are three values overflowing to 0 : 227 , 170 and 85 ( 85 going directly to 0 , other two progressing toward 85 ).

But there's no value creating cyclic overflow seed.

Funnily enough I did a check, which is the first number to suffer from 8 bit truncation, and already 27 is affected! It does reach value 9232 in proper non-truncated series (first truncated value is 322 in 12th step), and the maximum value reached for any of the 2-255 input numbers in non-truncated way is 13120 (for the 255 itself), maximum number of steps to converge to 1 is about 128 (+-2, not sure if "1" is to count, etc...).

Interestingly enough (for me) the number 9232 is maximum for many other source numbers, what's so special about it? :-O 9232 = 0x2410 ... hmmm.. no idea.

Unfortunately I can't get any deep grasp of this series, why does it converge and what are the implications of truncating them to k bits, but with cmp number,1 terminating condition it's certainly possible to put the algorithm into infinite loop with particular input value ending as 0 after truncation.

But the value 27 overflowing for 8 bit case is sort of alerting, this looks like if you count the number of steps to reach value 1 , you will get wrong result for majority of numbers from the total k-bit set of integers. For the 8 bit integers the 146 numbers out of 256 have affected series by truncation (some of them may still hit the correct number of steps by accident maybe, I'm too lazy to check).


更多的表現:一個簡單的變化是觀察到n = 3n + 1之後,n將是偶數,所以你可以立即除以2。 而n不會是1,所以你不需要測試它。 所以你可以節省一些如果陳述和寫:

while (n % 2 == 0) n /= 2;
if (n > 1) for (;;) {
    n = (3*n + 1) / 2;
    if (n % 2 == 0) {
        do n /= 2; while (n % 2 == 0);
        if (n == 1) break;
    }
}

這是一個巨大的勝利:如果你看看n的最低8位,所有步驟直到你除以2 8次完全由這8位決定。 例如,如果最後八位是0x01,那是二進制數,你的編號是? 0000 0001然後接下來的步驟是:

3n+1 -> ???? 0000 0100
/ 2  -> ???? ?000 0010
/ 2  -> ???? ??00 0001
3n+1 -> ???? ??00 0100
/ 2  -> ???? ???0 0010
/ 2  -> ???? ???? 0001
3n+1 -> ???? ???? 0100
/ 2  -> ???? ???? ?010
/ 2  -> ???? ???? ??01
3n+1 -> ???? ???? ??00
/ 2  -> ???? ???? ???0
/ 2  -> ???? ???? ????

因此,所有這些步驟都可以預測,並且256k + 1被替換為81k + 1.所有組合都會發生類似的情況。 所以你可以用一個大的switch語句做一個循環:

k = n / 256;
m = n % 256;

switch (m) {
    case 0: n = 1 * k + 0; break;
    case 1: n = 81 * k + 1; break; 
    case 2: n = 81 * k + 1; break; 
    ...
    case 155: n = 729 * k + 425; break;
    ...
}

運行循環直到n≤128,因為那時n可能會變成1,少於八個分數減2,並且一次執行八步或更多步驟會讓你錯過第一次達到1的點。 然後繼續“正常”循環 - 或準備一個表格,告訴你需要多少步才能達到1。

PS。 我強烈懷疑彼得科爾德的建議會讓它更快。 除了一個以外,將不會有任何條件分支,除非循環實際結束,否則一個分支將被正確預測。 所以代碼會是這樣的

static const unsigned int multipliers [256] = { ... }
static const unsigned int adders [256] = { ... }

while (n > 128) {
    size_t lastBits = n % 256;
    n = (n >> 8) * multipliers [lastBits] + adders [lastBits];
}

在實踐中,你會測量一次處理n的最後9,10,11,12位是否會更快。 對於每一位,表中的條目數量都會增加一倍,而且當表格不適合L1緩存時,我會緩慢放慢速度。

PPS。 如果你需要的操作次數:在每次迭代中,我們都會按正確的八個分區和可變數目的(3n + 1)操作,因此一個明顯的方法來計算操作將是另一個數組。 但我們實際上可以計算步數(根據循環的迭代次數)。

我們可以稍微重新定義問題:如果奇數將n替換為(3n + 1)/ 2,並且如果偶數則將n替換為n / 2。 然後每個迭代都會執行8個步驟,但是您可以考慮作弊:-)因此,假定有r個運算n <-3n + 1和s運算n <-n / 2。 由於n < - 3n + 1意味著n < - 3n *(1 + 1 / 3n),所以結果將完全是n'= n * 3 ^ r / 2 ^ s。 取對數,我們發現r =(s + log2(n'/ n))/ log2(3)。

如果我們執行循環,直到n≤1,000,000,並且有一個預先計算的表,從任何起始點n≤1,000,000需要多少次迭代,那麼如上所述計算r(四捨五入到最接近的整數)將會給出正確的結果,除非s確實很大。


On a rather unrelated note: more performance hacks!

  • [the first «conjecture» has been finally debunked by @ShreevatsaR; removed]

  • When traversing the sequence, we can only get 3 possible cases in the 2-neighborhood of the current element N (shown first):

    1. [even] [odd]
    2. [odd] [even]
    3. [even] [even]

    To leap past these 2 elements means to compute (N >> 1) + N + 1 , ((N << 1) + N + 1) >> 1 and N >> 2 , respectively.

    Let`s prove that for both cases (1) and (2) it is possible to use the first formula, (N >> 1) + N + 1 .

    Case (1) is obvious. Case (2) implies (N & 1) == 1 , so if we assume (without loss of generality) that N is 2-bit long and its bits are ba from most- to least-significant, then a = 1 , and the following holds:

    (N << 1) + N + 1:     (N >> 1) + N + 1:
    
            b10                    b1
             b1                     b
           +  1                   + 1
           ----                   ---
           bBb0                   bBb
    

    where B = !b . Right-shifting the first result gives us exactly what we want.

    QED: (N & 1) == 1 ⇒ (N >> 1) + N + 1 == ((N << 1) + N + 1) >> 1 .

    As proven, we can traverse the sequence 2 elements at a time, using a single ternary operation. Another 2× time reduction.

The resulting algorithm looks like this:

uint64_t sequence(uint64_t size, uint64_t *path) {
    uint64_t n, i, c, maxi = 0, maxc = 0;

    for (n = i = (size - 1) | 1; i > 2; n = i -= 2) {
        c = 2;
        while ((n = ((n & 3)? (n >> 1) + n + 1 : (n >> 2))) > 2)
            c += 2;
        if (n == 2)
            c++;
        if (c > maxc) {
            maxi = i;
            maxc = c;
        }
    }
    *path = maxc;
    return maxi;
}

int main() {
    uint64_t maxi, maxc;

    maxi = sequence(1000000, &maxc);
    printf("%llu, %llu\n", maxi, maxc);
    return 0;
}

Here we compare n > 2 because the process may stop at 2 instead of 1 if the total length of the sequence is odd.

[EDIT:]

Let`s translate this into assembly!

MOV RCX, 1000000;



DEC RCX;
AND RCX, -2;
XOR RAX, RAX;
MOV RBX, RAX;

@main:
  XOR RSI, RSI;
  LEA RDI, [RCX + 1];

  @loop:
    ADD RSI, 2;
    LEA RDX, [RDI + RDI*2 + 2];
    SHR RDX, 1;
    SHRD RDI, RDI, 2;    ror rdi,2   would do the same thing
    CMOVL RDI, RDX;      Note that SHRD leaves OF = undefined with count>1, and this doesn't work on all CPUs.
    CMOVS RDI, RDX;
    CMP RDI, 2;
  JA @loop;

  LEA RDX, [RSI + 1];
  CMOVE RSI, RDX;

  CMP RAX, RSI;
  CMOVB RAX, RSI;
  CMOVB RBX, RCX;

  SUB RCX, 2;
JA @main;



MOV RDI, RCX;
ADD RCX, 10;
PUSH RDI;
PUSH RCX;

@itoa:
  XOR RDX, RDX;
  DIV RCX;
  ADD RDX, '0';
  PUSH RDX;
  TEST RAX, RAX;
JNE @itoa;

  PUSH RCX;
  LEA RAX, [RBX + 1];
  TEST RBX, RBX;
  MOV RBX, RDI;
JNE @itoa;

POP RCX;
INC RDI;
MOV RDX, RDI;

@outp:
  MOV RSI, RSP;
  MOV RAX, RDI;
  SYSCALL;
  POP RAX;
  TEST RAX, RAX;
JNE @outp;

LEA RAX, [RDI + 59];
DEC RDI;
SYSCALL;

Use these commands to compile:

nasm -f elf64 file.asm
ld -o file file.o

See the C and an improved/bugfixed version of the asm by Peter Cordes on Godbolt . (editor's note: Sorry for putting my stuff in your answer, but my answer hit the 30k char limit from Godbolt links + text!)


The simple answer:

  • doing a MOV RBX, 3 and MUL RBX is expensive; just ADD RBX, RBX twice

  • ADD 1 is probably faster than INC here

  • MOV 2 and DIV is very expensive; just shift right

  • 64-bit code is usually noticeably slower than 32-bit code and the alignment issues are more complicated; with small programs like this you have to pack them so you are doing parallel computation to have any chance of being faster than 32-bit code

If you generate the assembly listing for your C++ program, you can see how it differs from your assembly.


在函數內部,字節碼是

  2           0 SETUP_LOOP              20 (to 23)
              3 LOAD_GLOBAL              0 (xrange)
              6 LOAD_CONST               3 (100000000)
              9 CALL_FUNCTION            1
             12 GET_ITER            
        >>   13 FOR_ITER                 6 (to 22)
             16 STORE_FAST               0 (i)

  3          19 JUMP_ABSOLUTE           13
        >>   22 POP_BLOCK           
        >>   23 LOAD_CONST               0 (None)
             26 RETURN_VALUE        

在頂層,字節碼是

  1           0 SETUP_LOOP              20 (to 23)
              3 LOAD_NAME                0 (xrange)
              6 LOAD_CONST               3 (100000000)
              9 CALL_FUNCTION            1
             12 GET_ITER            
        >>   13 FOR_ITER                 6 (to 22)
             16 STORE_NAME               1 (i)

  2          19 JUMP_ABSOLUTE           13
        >>   22 POP_BLOCK           
        >>   23 LOAD_CONST               2 (None)
             26 RETURN_VALUE        

不同之處在於STORE_FASTSTORE_NAME更快(!)。 這是因為在一個函數中, i是一個局部的,但是在頂層是一個全局的。

要檢查字節碼,請使用dis模塊 。 我能夠直接反彙編這個函數,但是為了反彙編我不得不使用compile內置的頂級代碼。





c++ performance assembly optimization x86