mht.wtf

Computer science, programming, and whatnot.

Code Generation and Merge Sort

April 24, 2019

I was reading a few pages of Knuth's The Art of Computer Programming, Volume 4A, about "branchless computation" (p. 180), in which he demonstrates how to get rid of branches by using conditional instructions. As an instructive example he considers the inner part of merge sort, in which we merge two sorted lists of numbers into one bigger sorted list. The description as given by Knuth is as follows:

If $x_i < y_j$ set $z_k \gets x_i$, $i \gets i+1$, and go to x_done if $i = i_{max}$.
Otherwise set $z_k \gets y_j$, $j \gets j+1$, and go to y_done if $j = j_{max}$.
Then set $k \gets k+1$ and go to z_done if $k = k_{max}$.

$x$ and $y$ are the input lists, $z$ is the output merged list. $i$, $j$, and $k$ are loop indices for the three respective lists, and the $_{max}$ variants are the lists' lengths.

I got curious and decided to see how a standard optimizing compiler would handle this case, and whether writing the assembly yourself would provide any gain in performance. After all, this is just slightly more complicated than the trivial examples used to show off good codegen, so it would not be unreasonable for the compiler to fix a bad implementation of this. In addition, it would serve as a great excuse to finally learn how to write x86.

Basics

Here's the inner loop in C code:

void branching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax, 
               uint64_t *zs, size_t zmax) {
  size_t i = 0, j = 0, k = 0;
  while (k < zmax) {
    if (xs[i] < ys[j]) {
      zs[k++] = xs[i++];
      if (i == xmax) { // x_done
        memcpy(zs + k, ys + j, 8 * (zmax - k));
        return; 
      }
    } else {
      zs[k++] = ys[j++];
      if (j == ymax) { // y_done
        memcpy(zs + k, xs + i, 8 * (zmax - k));
        return; 
      }
    }
  } // z_done
}

This seems to be a more or less straightforward textbook implementation of the procedure, so it will do fine as a benchmark. As a quick check before going any deeper into this we can use godbolt.org to see whether this experiment is even worth doing. Godbolt's x86-64 gcc 8.3 with -O3 spits out this (annotations are mine):

branching(unsigned long*, unsigned long, unsigned long*, unsigned long, 
          unsigned long*, unsigned long):
        test    r9, r9       ; if (r9 == 0)
        je      .L15         ;   goto .L15
        push    r13          ;
        xor     eax, eax     ;
        xor     r11d, r11d   ; j = 0
        xor     r10d, r10d   ; i = 0
        push    r12          ;
        push    rbp          ;
        push    rbx          ;
        jmp     .L2          ;
.L17:
        add     r10, 1                        ; i++
        mov     QWORD PTR [r8-8+rax*8], rbp   ; zs[k-1] = xi
        cmp     r10, rsi                      ; if (i == xmax)
        je      .L16                          ;   goto .L16
.L6:
        cmp     r9, rax      ; if (k == zmax)
        je      .L1          ;   goto .L1
.L2:
        lea     r12, [rdi+r10*8]             ; calculate xs + i
        lea     r13, [rdx+r11*8]             ; calculate ys + j
        add     rax, 1                       ; k++
        mov     rbp, QWORD PTR [r12]         ; xi = xs[i]
        mov     rbx, QWORD PTR [r13+0]       ; yj = ys[j]
        cmp     rbp, rbx                     ; if (xi < yj)
        jb      .L17                         ;   goto .L17
        add     r11, 1                       ; j++
        mov     QWORD PTR [r8-8+rax*8], rbx  ; zs[k-1] = yj
        cmp     r11, rcx                     ; if (j != ymax)
        jne     .L6                          ;   goto .L6
        sub     r9, rax            ; y_done 
        pop     rbx                ;
        mov     rsi, r12           ;
        pop     rbp                ;
        lea     rdi, [r8+rax*8]    ;
        pop     r12                ;
        lea     rdx, [0+r9*8]      ;
        pop     r13                ;
        jmp     memcpy             ;
.L1:
        pop     rbx       ; z_done
        pop     rbp       ;
        pop     r12       ;
        pop     r13       ; 
        ret               ;
.L16:
        sub     r9, rax            ; x_done
        pop     rbx                ;
        mov     rsi, r13           ;
        pop     rbp                ;
        lea     rdi, [r8+rax*8]    ;
        pop     r12                ;
        lea     rdx, [0+r9*8]      ;
        pop     r13                ;
        jmp     memcpy             ;
.L15:
        ret

Plenty of branches!1

Now, maybe it turns out that it doesn't matter whether we branch or not and that the compiler knows best. We could guess that the reason we're still getting branches is that this really is the best way to go here. After all, "you can't beat the compiler" seems to be the consensus in many programming circles. Let's try to write a version in C without excessive use of branching. Then perhaps the compiler will generate different code, and we can see what that difference amounts to in terms of running time. We can adapt Knuth's branchless version:

void nonbranching_but_branching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax, 
                                uint64_t *zs, size_t zmax) {
  size_t i = 0, j = 0, k = 0;
  uint64_t xi = xs[i], yj = ys[j];
  while ((i < xmax) && (j < ymax) && (k < zmax)) {
    int64_t t = one_if_lt(xi - yj);
    yj = min(xi, yj);
    zs[k] = yj;
    i += t;
    xi = xs[i];
    t ^= 1;
    j += t;
    yj = ys[j];
    k += 1;
  }
  if (i == xmax)
    memcpy(zs + k, ys + j, 8 * (zmax - k));
  if (j == ymax)
    memcpy(zs + k, xs + i, 8 * (zmax - k));
}

What is going on, you might ask? The general idea is to first get min(xi, yj), and then have a number t that's 1 if xi < yj and 0 otherwise: we can add t to i, since t = 1 exactly when we just wrote xi to zs[k]. Then we can xor t with 1, effectively flipping 1 to 0 and 0 to 1, and add the result to j; this causes either i or j to be incremented, but not both. We use two convenience functions here, one_if_lt and min, both implemented straightforwardly with branching, hoping that the compiler will figure this out for us now that the branches are much smaller.
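
The two helpers aren't shown in the post, so here is a plausible sketch of them (the names match the code above, the bodies are my guess), each deliberately written with a small branch:

static inline int64_t one_if_lt(uint64_t diff) {
  // diff is xi - yj; reinterpreted as signed it is negative exactly when
  // xi < yj (as long as the inputs' top bits are clear).
  return ((int64_t) diff < 0) ? 1 : 0;
}

static inline uint64_t min(uint64_t a, uint64_t b) {
  return (a < b) ? a : b;
}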

Next, if we cheat a little and assume that the highest bit in the numbers is never set, we can get rid of those branches2:

void nonbranching(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax, 
                  uint64_t *zs, size_t zmax) {
  size_t i = 0, j = 0, k = 0;
  uint64_t xi = xs[i], yj = ys[j];
  while ((i < xmax) && (j < ymax) && (k < zmax)) {
    uint64_t neg = (xi - yj) >> 63;
    yj = neg * xi + (1 - neg) * yj;
    zs[k] = yj;
    i += neg;
    xi = xs[i];
    neg ^= 1;
    j += neg;
    yj = ys[j];
    k += 1;
  }
  if (i == xmax)
    memcpy(zs + k, ys + j, 8 * (zmax - k));
  if (j == ymax)
    memcpy(zs + k, xs + i, 8 * (zmax - k));
}

What is up with (xi - yj) >> 63, you may ask? If xi < yj the "true" difference is negative, so the unsigned subtraction wraps around and the most significant bit of the result is set. Then we shift down logically (since we're using unsigned integers3), so the bits that are filled in are all zeroes. Since the width is 64, we effectively move the upper bit to the lowest position while setting all other bits to zero.
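
As a tiny sanity check of the trick, here's a concrete example (mine, not from the post):

#include <assert.h>
#include <stdint.h>

int main(void) {
  uint64_t xi = 3, yj = 5;
  // 3 - 5 wraps around to 0xFFFFFFFFFFFFFFFE, whose top bit is set,
  // so the logical shift by 63 leaves exactly 1.
  assert(((xi - yj) >> 63) == 1); // xi <  yj
  assert(((yj - xi) >> 63) == 0); // yj >= xi
  return 0;
}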

Knuth has another quirk, namely that his array pointers usually point to the end of the array, and his indices are negative, going from -xmax up to 0 instead of the more standard 0 up to xmax. One consequence of this is that the termination check can be done with one test instead of three, by anding the three indices together: as long as all of them are negative they all have their most significant bit set, so the and is non-zero; once any of them reaches zero, the and becomes zero. Here are both of the previous versions with this reversal trick:

void nonbranching_but_branching_reverse(uint64_t *xs, size_t xmax, 
                                        uint64_t *ys, size_t ymax, 
                                        uint64_t *zs, size_t zmax) {
  uint64_t *xse = xs + xmax;
  uint64_t *yse = ys + ymax;
  uint64_t *zse = zs + zmax;

  ssize_t i = -((ssize_t) xmax);
  ssize_t j = -((ssize_t) ymax);
  ssize_t k = -((ssize_t) zmax);

  uint64_t xi = xse[i], yj = yse[j];
  while (i & j & k) {
    uint64_t t = one_if_lt(xi - yj);
    yj = min(xi, yj);
    zse[k] = yj;
    i += t;
    xi = xse[i];
    t ^= 1;
    j += t;
    yj = yse[j];
    k += 1;
  }
  if (i == 0)
    memcpy(zse + k, yse + j, -8 * k);
  if (j == 0)
    memcpy(zse + k, xse + i, -8 * k);
}

void nonbranching_reverse(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax, 
                          uint64_t *zs, size_t zmax) {
  uint64_t *xse = xs + xmax;
  uint64_t *yse = ys + ymax;
  uint64_t *zse = zs + zmax;

  ssize_t i = -((ssize_t) xmax);
  ssize_t j = -((ssize_t) ymax);
  ssize_t k = -((ssize_t) zmax);

  uint64_t xi = xse[i], yj = yse[j];
  while (i & j & k) {
    uint64_t neg = (xi - yj) >> 63;
    yj = neg * xi + (1 - neg) * yj;
    zse[k] = yj;
    i += neg;
    xi = xse[i];
    neg ^= 1;
    j += neg;
    yj = yse[j];
    k += 1;
  }
  if (i == 0)
    memcpy(zse + k, yse + j, -8 * k);
  if (j == 0)
    memcpy(zse + k, xse + i, -8 * k);
}

Technically, I suppose we do assume that the lengths of the arrays are not >2**63, so that they fit in an ssize_t, but considering that the address space of x86-64 is not 64 bits, but merely 48 bits4, this is not a problem, even in theory.
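
To convince ourselves that the single i & j & k test really replaces three comparisons, here's a quick illustration (my own, not from the post):

#include <assert.h>
#include <sys/types.h>

int main(void) {
  // While all three indices are negative they all have the sign bit set,
  // so their and is non-zero and the loop keeps running.
  ssize_t i = -3, j = -7, k = -10;
  assert((i & j & k) != 0);
  // As soon as any index reaches zero, the and becomes zero and we stop.
  i = 0;
  assert((i & j & k) == 0);
  return 0;
}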

Writing the ASM ourselves

Lastly, we can try to write the assembly ourselves. When translating the branch-free routine by Knuth into x86 there are a number of things to do. First we need to figure out how to get -1/0/+1 by comparing two variables, as MMIX's CMP instruction does. However, instead of trying to translate this line by line, which would leave us with more instructions than needed, we should look more closely at what we're doing, so that we really understand the minimal amount of work we have to do.

We only need to do two things: compare $x_i$ and $y_j$ and load the smaller into a register, and increment either i or j. The former can be done using cmovl, and the latter can be done in a similar fashion to how Knuth does it, which is basically what we've been doing up to this point in C. This is the version I ended up with (here in GCC inline asm format):

1: mov   %[minxy], %[yj]                     ;
   cmp   %[xi], %[yj]                        ; minxy = min(xi, yj)
   cmovl %[minxy], %[xi]                     ;
   mov   QWORD PTR [%[zse]+8*%[k]], %[minxy] ; zs[k] = minxy
   mov   %[t], 0                             ; t = 0
   cmovl %[t], %[one]                        ; if xi < yj: t = 1
   add   %[i], %[t]                          ; i += t
   mov   %[xi], QWORD PTR [%[xse]+8*%[i]]    ; xi = xs[i]
   xor   %[t], 1                             ; t ^= 1
   add   %[j], %[t]                          ; j += t
   mov   %[yj], QWORD PTR [%[yse]+8*%[j]]    ; yj = ys[j]
   add   %[k], 1                             ; k += 1
   mov   %[u], %[i]                          ; 
   and   %[u], %[j]                          ;
   test  %[u], %[k]                          ; if ((i & j & k) != 0)
   jnz   1b                                  ;   goto 1

There are a few quirks here, like having a couple of mov instructions in between the cmp and the second cmovl that depends on its flags, and the fact that cmovl can't take an immediate operand, so I had to set up a register holding just the value 1. A sneaky detail to keep in mind is that when we set t = 0 we cannot use the trick of xoring t with itself, since that would change the flags and make the subsequent cmovl wrong.

Now we can take a look at the assembly generated for some of the other functions by using objdump -d. Our own programs are compiled with -O3 -march=native. Here is the inner loop in nonbranching_reverse:

<nonbranching_reverse>:
1ef0:	mov    rax,rdi
1ef3:	sub    rax,rsi
1ef6:	shr    rax,0x3f
1efa:	mov    rdx,r8
1efd:	sub    rdx,rax
1f00:	imul   rdx,rsi
1f04:	imul   rdi,rax
1f08:	add    rbp,rax
1f0b:	xor    rax,0x1
1f0f:	add    rdi,rdx
1f12:	mov    QWORD PTR [r13+r12*8+0x0],rdi
1f17:	add    rcx,rax
1f1a:	inc    r12
1f1d:	mov    rax,rbp
1f20:	and    rax,r12
1f23:	mov    rdi,QWORD PTR [rbx+rbp*8]
1f27:	mov    rsi,QWORD PTR [r10+rcx*8]
1f2b:	test   rax,rcx
1f2e:	jne    1ef0 <nonbranching_reverse+0x40>

Sure looks a lot better than branching! This seems more or less reasonable, but we can see that the multiplication trickery that we used to avoid the min branch takes up some space here; presumably it also takes some time. Maybe one little branch isn't too bad though, and perhaps the compiler is more willing to use conditional instructions if we use the ternary operator, like this:

void nonbranching_reverse_ternary(uint64_t *xs, size_t xmax, uint64_t *ys, size_t ymax, 
                                  uint64_t *zs, size_t zmax) {
  uint64_t *xse = xs + xmax;
  uint64_t *yse = ys + ymax;
  uint64_t *zse = zs + zmax;

  ssize_t i = -((ssize_t) xmax);
  ssize_t j = -((ssize_t) ymax);
  ssize_t k = -((ssize_t) zmax);

  uint64_t xi = xse[i], yj = yse[j];
  while (i & j & k) {
    uint64_t ybig = (xi - yj) >> 63;
    yj = ybig ? xi : yj;
    zse[k] = yj;
    i += ybig;
    xi = xse[i];
    ybig ^= 1;
    j += ybig;
    yj = yse[j];
    k += 1;
  }
  if (i == 0)
    memcpy(zse + k, yse + j, -8 * k);
  if (j == 0)
    memcpy(zse + k, xse + i, -8 * k);
}

This time, if we look at the assembly, we can see that the compiler is finally getting it: cmove!

2080:	mov    rax,yj                     ;
2083:	sub    rax,xi                     ;
2086:	shr    rax,0x3f                   ; t = (yj - xi) >> 63
208a:	cmove  yj,xi                      ; yj = t == 0 ? xi : yj
208e:	add    j,rax                      ; j += t
2091:	mov    QWORD PTR [zs+k*8],yj      ; z[k] = yj
2096:	xor    rax,0x1                    ; t ^= 1
209a:	inc    k                          ; k++
209d:	add    i,rax                      ; i += t
20a0:	mov    rax,k                      ; 
20a3:	and    rax,j                      ; t = k & j
20a6:	mov    yj,QWORD PTR [ys+j*8]      ; yj = ys[j]
20aa:	mov    xi,QWORD PTR [xs+i*8]      ; xi = xs[i]
20ae:	test   rax,i                      ; if ((i & j & k) != 0)
20b1:	jne    2080                       ; goto .2080

So we see it's essentially the same loop as the one we wrote by hand! Curiously, the compiler turned our code around so that its t is 1 if xi is the bigger, whereas our ybig was 1 if yj was the bigger.

Results

And now for the results! We fill two arrays with random elements and run branching on them, so that we get the merged array back. This is used as the ground truth against which all other variations are checked, in case we have messed up. Then we use clock_gettime to measure the wall-clock time spent per method. The following are running times in milliseconds where both lists are 2**25 elements long, averaged over 100 runs; 10 iterations per seed and 10 different seeds (srand(i) for each iteration).
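
The measurement itself boils down to something like this sketch (my own reconstruction of the kind of clock_gettime timing described above, not the actual harness; CLOCK_MONOTONIC is my choice):

#include <stddef.h>
#include <stdint.h>
#include <time.h>

typedef void merge_fn(uint64_t *, size_t, uint64_t *, size_t, uint64_t *, size_t);

// Wall-clock milliseconds for one merge of xs (length n) and ys (length m)
// into zs (length n + m).
static double time_merge_ms(merge_fn *merge, uint64_t *xs, size_t n,
                            uint64_t *ys, size_t m, uint64_t *zs) {
  struct timespec t0, t1;
  clock_gettime(CLOCK_MONOTONIC, &t0);
  merge(xs, n, ys, m, zs, n + m);
  clock_gettime(CLOCK_MONOTONIC, &t1);
  return (t1.tv_sec - t0.tv_sec) * 1e3 + (t1.tv_nsec - t0.tv_nsec) / 1e6;
}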

These are the numbers I got on an Intel i7-7500U@2.7GHz (avg +/- var):

branching:                          30.998 +/- 0.001
nonbranching_but_branching:         27.330 +/- 0.002
nonbranching:                       24.770 +/- 0.000
nonbranching_but_branching_reverse: 19.387 +/- 0.000
nonbranching_reverse:               20.015 +/- 0.000
nonbranching_reverse_ternary:       19.038 +/- 0.000
asm_nb_rev:                         18.987 +/- 0.001

I also ran the suite on another machine with an Intel i5-8250U@1.60GHz, in order to see if there would be any significant difference:

branching:                          31.405 +/- 0.034
nonbranching_but_branching:         27.646 +/- 0.097
nonbranching:                       27.894 +/- 0.021
nonbranching_but_branching_reverse: 22.760 +/- 0.040
nonbranching_reverse:               21.284 +/- 0.050
nonbranching_reverse_ternary:       19.299 +/- 0.002
asm_nb_rev:                         19.793 +/- 0.009

Interestingly, on this CPU our assembly is slightly slower than the ternary version; I guess this is due to us using a cmovl where the compiler-generated version used the shifting trick.

Bonus: Sorting

We can't possibly have done all this merging without making a proper mergesort in the end! Luckily for us, the merge part is really the only difficult part of the routine:

void merge_sort(uint64_t *xs, size_t n, uint64_t *buf) {
  if (n < 2) return;
  size_t h = n / 2;
  merge_sort(xs, h, buf);
  merge_sort(xs + h, n - h, buf + h);
  merge(xs, h, xs + h, n - h, buf, n);
  memcpy(xs, buf, 8 * n);
}
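
Here merge is whichever of the merge variants above is being benchmarked, giving sort_branching and friends in the results below. Driving the sort could look roughly like this (my own illustration); the caller just supplies a scratch buffer as large as the input:

#include <stdint.h>
#include <stdlib.h>

int main(void) {
  size_t n = 1 << 20;
  uint64_t *xs  = malloc(n * sizeof *xs);
  uint64_t *buf = malloc(n * sizeof *buf);
  for (size_t i = 0; i < n; i++)
    xs[i] = (uint64_t) rand();
  merge_sort(xs, n, buf);
  free(xs);
  free(buf);
  return 0;
}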

Unfortunately we have to merge into a buffer and then memcpy the result back. Perhaps this is fixable: we can make the sorting routine put its result in either xs or buf, and by having the recursive calls report which one they used, we can merge into the other, assuming both recursive calls agree(!!5). That is, if the recursive calls say that the sorted subarrays are in xs, we merge into buf and tell our caller that our result is in buf. At the end, we just need to make sure that the final sorted numbers are in xs.

void _sort_asm(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {
  if (n < 2) {
    *into_buf = 0;
    return;
  }
  size_t h = n / 2;
  int res_in_buf;
  _sort_asm(xs, h, buf, &res_in_buf);             // WARNING: `res_in_buf` for the two calls
  _sort_asm(xs + h, n - h, buf + h, &res_in_buf); // need not be the same in the real world!
  *into_buf = res_in_buf ^ 1;
  if (res_in_buf)
    asm_nb_rev(buf, h, buf + h, n - h, xs, n);
  else
    asm_nb_rev(xs, h, xs + h, n - h, buf, n);
}

void sort_asm(uint64_t *xs, size_t n, uint64_t *buf) {
  int res_in_buf;
  _sort_asm(xs, n, buf, &res_in_buf);
  if (res_in_buf) {
    memcpy(xs, buf, 8 * n);
  }
}

and similarly for the other variants. You might see the branch and wonder if we can remove it. I tried, by making an array {xs, buf} and indexing it with res_in_buf, but it caused a minor slowdown: maybe some branching is fine after all.
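
For the curious, that attempt looked roughly like this (a sketch; the function name and exact shape are mine, not the post's):

void _sort_asm_indexed(uint64_t *xs, size_t n, uint64_t *buf, int *into_buf) {
  if (n < 2) {
    *into_buf = 0;
    return;
  }
  size_t h = n / 2;
  int res_in_buf;
  _sort_asm_indexed(xs, h, buf, &res_in_buf);
  _sort_asm_indexed(xs + h, n - h, buf + h, &res_in_buf);
  *into_buf = res_in_buf ^ 1;
  // Pick the source and destination by indexing instead of branching.
  uint64_t *arrs[2] = { xs, buf };
  uint64_t *src = arrs[res_in_buf];
  uint64_t *dst = arrs[res_in_buf ^ 1];
  asm_nb_rev(src, h, src + h, n - h, dst, n);
}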

Here are the running times:

                                         i7-7500U              i5-8250U
sort_branching:                          369.479 +/- 0.047     393.762 +/- 0.082
sort_nonbranching_but_branching:         324.337 +/- 0.014     337.120 +/- 0.099
sort_nonbranching:                       325.658 +/- 0.028     352.802 +/- 0.120
sort_nonbranching_but_branching_reverse: 279.237 +/- 0.164     287.799 +/- 0.154
sort_nonbranching_reverse:               283.927 +/- 0.033     299.277 +/- 0.929
sort_nonbranching_reverse_ternary:       270.668 +/- 0.009     278.644 +/- 1.677
sort_asm_nb_rev:                         270.228 +/- 0.009     281.657 +/- 0.360

If you would like to run the suite yourself, the git repo is available here.

Thanks for reading.

Footnotes

  1. Originally I had omitted the _done parts and the code was much cleaner; I'm not sure why including them complicates things this much. Also, why is k incremented before the store, so that we have to write zs[k-1] instead of zs[k]?

  2. Curiously, if we change from uint64_t to int64_t and use ((a-b)>>63)&1 for the test we do not depend on the magnitudes of the numbers (as the compiler can assume signed overflow will not happen); also the and never makes it to the assembly, and we still use logical instead of arithmetic shift.

  3. The alternative is arithmetic shift in which the sign bit is propagated down. In this case we would end up with either all zeroes or all ones.

  4. https://en.wikipedia.org/wiki/X86-64#Virtual_address_space_details

  5. This is really only the case if n is a power of two: otherwise you'll have two siblings in the call tree with different ns, and this difference will cause two leaf nodes to be at different depths, which in turn will make them "out of sync".

This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License