
Optimize heapq.merge with for-break(-else) pattern? #108880

Open
pochmann opened this issue Sep 4, 2023 · 11 comments
Labels
performance (Performance or resource usage) · stdlib (Python modules in the Lib dir) · type-feature (A feature request or enhancement)

Comments

@pochmann
Contributor

pochmann commented Sep 4, 2023

Feature or enhancement

Has this already been discussed elsewhere?

No response given

Links to previous discussion of this feature:

https://discuss.python.org/t/optimize-heapq-merge-with-for-break-else-pattern/32931?u=pochmann
(@serhiy-storchaka said to do this on GitHub instead)

Proposal:

@rhettinger just mentioned a willingness for optimized heapq.merge:

How messy and convoluted are we willing to make the code to save a few cycles? In some places like heapq.merge and random.__init_subclass__, the answer is that we go quite far. Elsewhere, we aim for simplicity.

I propose to optimize heapq.merge by using what I call the for-break(-else) pattern to get only the next element of an iterator (using an unconditional break). In my experience this is significantly faster than calling next() or __next__, and depending on the case it can also lead to simpler code.
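A minimal sketch of the idiom on its own, outside of merge: an unconditional break takes at most one element, and the else clause runs only when the iterator was already exhausted. The helper name first_or is hypothetical; it does the same job as the built-in next(it, default).

```python
def first_or(it, default):
    """Return the next element of iterator *it*, or *default* if it is exhausted."""
    for value in it:
        break            # unconditional: take exactly one element
    else:
        value = default  # loop body never ran: the iterator was empty
    return value


it = iter([10, 20])
print(first_or(it, None))   # 10
print(first_or(it, None))   # 20
print(first_or(it, None))   # None (exhausted)
```

One difference from next(): the for statement first calls iter() on its operand. For built-in iterators that simply returns the same object, but it costs a call for iterators whose __iter__ is written in Python.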

Benchmark for merging three sorted lists of 10,000 elements each:

  6.94 ± 0.05 ms  merge_proposal
  7.31 ± 0.07 ms  merge
Python: 3.11.4 (main, Jun 24 2023, 10:18:04) [GCC 13.1.1 20230429]

Here's a comparison with the current implementation:

Initialization: Put the non-empty iterators into the heap (note I include the iterator itself instead of its __next__):

Current:

        for order, it in enumerate(map(iter, iterables)):
            try:
                next = it.__next__
                h_append([next(), order * direction, next])
            except StopIteration:
                pass

Proposal:

        for order, it in enumerate(map(iter, iterables)):
            for value in it:
                h_append([value, order * direction, it])
                break

Merging while multiple iterators remain:

Current:

        while len(h) > 1:
            try:
                while True:
                    value, order, next = s = h[0]
                    yield value
                    s[0] = next()           # raises StopIteration when exhausted
                    _heapreplace(h, s)      # restore heap condition
            except StopIteration:
                _heappop(h)                 # remove empty iterator

Proposal:

        while len(h) > 1:
            while True:
                value, order, it = s = h[0]
                yield value
                for s[0] in it:
                    _heapreplace(h, s)      # restore heap condition
                    break
                else:
                    _heappop(h)             # remove empty iterator
                    break
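The proposed loop relies on a less familiar detail: the target of a for statement can be any assignment target, including a subscription, so `for s[0] in it:` stores each fetched value straight into the heap entry. A small standalone illustration (the names are hypothetical, only the entry shape mirrors merge):

```python
s = ['old', 0, None]          # shaped like a heap entry: [value, order, iterator]
it = iter(['new', 'newer'])
s[2] = it

for s[0] in it:               # assigns directly into s[0], no temporary name
    break
print(s[0])                   # new

for s[0] in it:
    break
print(s[0])                   # newer

for s[0] in it:
    break
else:
    # Exhausted: no assignment happened, so s[0] still holds the last value.
    print('exhausted, s[0] is still', s[0])
```

In merge the stale value left in an exhausted entry never matters, because that entry is removed with _heappop right away.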

End when only one iterator remains:

Current:

            # fast case when only a single iterator remains
            value, order, next = h[0]
            yield value
            yield from next.__self__

Proposal:

            # fast case when only a single iterator remains
            value, order, it = h[0]
            yield value
            yield from it
Benchmark script:

Attempt This Online!

import random
from timeit import timeit
from statistics import mean, stdev
from collections import deque
import sys
from heapq import *


def merge_proposal(*iterables, key=None, reverse=False):
    '''Merge multiple sorted inputs into a single sorted output.

    Similar to sorted(itertools.chain(*iterables)) but returns a generator,
    does not pull the data into memory all at once, and assumes that each of
    the input streams is already sorted (smallest to largest).

    >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25]))
    [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25]

    If *key* is not None, applies a key function to each element to determine
    its sort order.

    >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len))
    ['dog', 'cat', 'fish', 'horse', 'kangaroo']

    '''

    h = []
    h_append = h.append

    if reverse:
        _heapify = _heapify_max
        _heappop = _heappop_max
        _heapreplace = _heapreplace_max
        direction = -1
    else:
        _heapify = heapify
        _heappop = heappop
        _heapreplace = heapreplace
        direction = 1

    if key is None:
        for order, it in enumerate(map(iter, iterables)):
            for value in it:
                h_append([value, order * direction, it])
                break
        _heapify(h)
        while len(h) > 1:
            while True:
                value, order, it = s = h[0]
                yield value
                for s[0] in it:
                    _heapreplace(h, s)      # restore heap condition
                    break
                else:
                    _heappop(h)             # remove empty iterator
                    break
        if h:
            # fast case when only a single iterator remains
            value, order, it = h[0]
            yield value
            yield from it
        return

    # Omitted the code for non-None key case
        

funcs = merge, merge_proposal

n = 10 ** 4
iterables = [
    sorted(random.choices(range(n), k=n))
    for _ in range(3)
]

expect = list(merge(*iterables))
for f in funcs:
    result = list(f(*iterables))
    print(result == expect, f.__name__)
                  
times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e3 for t in sorted(times[f])[:5]]
    return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '

for _ in range(100):
    for f in funcs:
        t = timeit(lambda: deque(f(*iterables), 0), number=1)
        times[f].append(t)

for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)

print('Python:', sys.version)
@pochmann pochmann added the type-feature A feature request or enhancement label Sep 4, 2023
@serhiy-storchaka
Member

Is this on a debug or an optimized build of Python?

What are the results for short lists? What if one list is much shorter than the others? What if you merge 100 lists of different lengths (1, 2, ... 100 elements) in order of increasing or decreasing length? Try to find the worst-case examples for the new code, and we will see whether such cases can be ignored.

@serhiy-storchaka serhiy-storchaka added the performance Performance or resource usage label Sep 4, 2023
@pochmann
Contributor Author

pochmann commented Sep 4, 2023

Extended benchmark with other types (converting the input lists to those types beforehand):

list
  6.61 ± 0.02 ms  merge_proposal
  7.10 ± 0.06 ms  merge

tuple
  6.46 ± 0.01 ms  merge_proposal
  6.96 ± 0.02 ms  merge

dict
  4.23 ± 0.03 ms  merge_proposal
  4.63 ± 0.03 ms  merge

deque
  6.39 ± 0.09 ms  merge_proposal
  6.89 ± 0.02 ms  merge

generator
  7.49 ± 0.04 ms  merge_proposal
  7.96 ± 0.03 ms  merge

string
  7.99 ± 0.01 ms  merge_proposal
  8.71 ± 0.01 ms  merge

class Iterator
  8.34 ± 0.08 ms  merge
 11.38 ± 0.07 ms  merge_proposal

Python: 3.11.4 (main, Jun 24 2023, 10:18:04) [GCC 13.1.1 20230429]

The last one is an iterator whose __iter__ is a Python method that just does return self. I rarely write such classes myself: even for custom iterable classes, I'd rather make their __iter__ a generator than have it return an instance of a custom iterator class.

class Iterator:
    def __init__(self, iterable):
        self.iter = iter(iterable)
    def __iter__(self):
        return self
    def __next__(self):
        return next(self.iter)

(Maybe if operator.identity existed, implemented in C, we could do __iter__ = identity and it could be fast... Or if Python recognized and optimized such functions...)
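The regression for class Iterator can be made visible by counting calls: every for statement starts by calling iter() on its operand, so with the for-break pattern a Python-level __iter__ runs once per element taken (plus once more to detect exhaustion), whereas the current code binds __next__ once and never calls __iter__ again. A sketch with a hypothetical counting variant of the Iterator class above (not counting the single initial iter() that both versions of merge perform):

```python
class CountingIterator:
    """Like the Iterator class above, but counts calls to its Python __iter__."""
    def __init__(self, iterable):
        self.iter = iter(iterable)
        self.iter_calls = 0

    def __iter__(self):
        self.iter_calls += 1
        return self

    def __next__(self):
        return next(self.iter)


# for-break(-else) style: one iter() call, hence one __iter__ call, per attempt.
a = CountingIterator(range(5))
taken = []
while True:
    for value in a:
        taken.append(value)
        break
    else:
        break
print(taken, a.iter_calls)    # [0, 1, 2, 3, 4] 6

# Current style: bind __next__ once; __iter__ is never called.
b = CountingIterator(range(5))
next_ = b.__next__
taken = []
try:
    while True:
        taken.append(next_())
except StopIteration:
    pass
print(taken, b.iter_calls)    # [0, 1, 2, 3, 4] 0
```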

Full code:

Attempt This Online!

import random
from timeit import timeit
from statistics import mean, stdev
from collections import deque
import sys
from heapq import *


def merge_proposal(*iterables, key=None, reverse=False):
    '''Merge multiple sorted inputs into a single sorted output.

    Similar to sorted(itertools.chain(*iterables)) but returns a generator,
    does not pull the data into memory all at once, and assumes that each of
    the input streams is already sorted (smallest to largest).

    >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25]))
    [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25]

    If *key* is not None, applies a key function to each element to determine
    its sort order.

    >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len))
    ['dog', 'cat', 'fish', 'horse', 'kangaroo']

    '''

    h = []
    h_append = h.append

    if reverse:
        _heapify = _heapify_max
        _heappop = _heappop_max
        _heapreplace = _heapreplace_max
        direction = -1
    else:
        _heapify = heapify
        _heappop = heappop
        _heapreplace = heapreplace
        direction = 1

    if key is None:
        for order, it in enumerate(map(iter, iterables)):
            for value in it:
                h_append([value, order * direction, it])
                break
        _heapify(h)
        while len(h) > 1:
            while True:
                value, order, it = s = h[0]
                yield value
                for s[0] in it:
                    _heapreplace(h, s)      # restore heap condition
                    break
                else:
                    _heappop(h)             # remove empty iterator
                    break
        if h:
            # fast case when only a single iterator remains
            value, order, it = h[0]
            yield value
            yield from it
        return

    # Omitted the code for non-None key case
        

funcs = merge, merge_proposal

n = 10 ** 4
iterables = [
    sorted(random.choices(range(n), k=n))
    for _ in range(3)
]

expect = list(merge(*iterables))
for f in funcs:
    result = list(f(*iterables))
    print(result == expect, f.__name__)

class Iterator:
    def __init__(self, iterable):
        self.iter = iter(iterable)
    def __iter__(self):
        return self
    def __next__(self):
        return next(self.iter)

types = [
    ('list', list),
    ('tuple', tuple),
    ('dict', dict.fromkeys),
    ('deque', deque),
    ('generator', lambda iterable: (x for x in iterable)),
    ('string', lambda iterable: ''.join(map(chr, iterable))),
    ('class Iterator', Iterator),
]

for label, converter in types:
    print(label)

    times = {f: [] for f in funcs}
    def stats(f):
        ts = [t * 1e3 for t in sorted(times[f])[:5]]
        return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '

    for _ in range(100):
        for f in funcs:
            converted = list(map(converter, iterables))
            t = timeit(lambda: deque(f(*converted), 0), number=1)
            times[f].append(t)

    for f in sorted(funcs, key=stats):
        print(stats(f), f.__name__)
    print()

print('Python:', sys.version)

@pochmann
Contributor Author

pochmann commented Sep 4, 2023

Is this on a debug or an optimized build of Python?

I don't know. I don't have a working PC at the moment, so I ran this on the linked ATO site. But a while ago I got similar results on my Windows PC with the standard installer from python.org.

What are the results for short lists? What if one list is much shorter than the others?

How much (if this isn't covered by your next request)?

What if you merge 100 lists of different lengths (1, 2, ... 100 elements) in order of increasing or decreasing length?

Increasing lengths:

  2.47 ± 0.02 ms  merge_proposal
  2.64 ± 0.01 ms  merge

Decreasing lengths:

  2.53 ± 0.01 ms  merge_proposal
  2.67 ± 0.04 ms  merge

Used setup:

n = 10 ** 4
iterables = [
    sorted(random.choices(range(n), k=length))
    for length in range(1, 101)[::-1]
]

Try to find the worst-case examples for the new code, and we will see whether such cases can be ignored.

See "class Iterator" in the updated extended benchmark above.

@tim-one
Member

tim-one commented Sep 4, 2023

Note that the proposal is actually much slower under PyPy (7.3.12 on 64-bit Windows), leaving its timings worse than under current CPython. That's not uncommon for "clever" code - PyPy likes code as straightforward and "primitive" as possible.

Wish I could say something more helpful, but don't know enough. Offhand, PyPy is fiercely focused on optimizing loops, and a loop that - by design - only ever goes around at most once (for s[0] in it:) must fight everything it expects about loops. In any case, with such a small loop count, I believe (but don't know) it won't even try to JITify the loop body.

True merge
True merge_proposal
list
  4.41 ± 0.02 ms  merge
 11.32 ± 0.05 ms  merge_proposal

tuple
  4.82 ± 0.07 ms  merge
 11.30 ± 0.10 ms  merge_proposal

dict
  3.11 ± 0.04 ms  merge
  6.79 ± 0.03 ms  merge_proposal

deque
  4.94 ± 0.03 ms  merge
 11.35 ± 0.06 ms  merge_proposal

generator
  5.05 ± 0.05 ms  merge
 11.48 ± 0.03 ms  merge_proposal

string
  5.47 ± 0.04 ms  merge
 12.49 ± 0.04 ms  merge_proposal

class Iterator
  5.05 ± 0.01 ms  merge
 11.26 ± 0.03 ms  merge_proposal

Python: 3.10.12 (af44d0b8114cb82c40a07bb9ee9c1ca8a1b3688c, Jun 15 2023, 15:42:22)
[PyPy 7.3.12 with MSC v.1929 64 bit (AMD64)]

@tim-one
Member

tim-one commented Sep 5, 2023

Sorry - fighting more illusions here! PyPy shows a very similar "slowdown" if I copy/paste the current CPython heapq.py's merge() source into the benchmark and change its name there to merge_proposal(). The heapq.py that ships with PyPy 7.3.12 on Windows has a very different merge implementation than CPython's, and doesn't actually use heaps at all.

But I'm at a loss to account for where it comes from. CPython has never had an implementation like it, and I don't understand PyPy's workflow well enough to figure out how to access its checkin history.

In any case, I was comparing two entirely different merge() algorithms, and while the results demonstrate that PyPy's implementation works much better for PyPy, nothing can actually be concluded about your proposed change to CPython's implementation.

@tim-one
Member

tim-one commented Sep 5, 2023

OK, apples to apples under PyPy. If I copy current CPython main heapq.merge source into the benchmark, then PyPy's implementation of merge() becomes irrelevant, and merge_proposal() has a small (1-2%) but consistent speed advantage over the current CPython implementation.

@tim-one
Member

tim-one commented Sep 5, 2023

Huh! I thought that code rang a bell. Looks like PyPy adopted a merge() approach Dennis Sweeney originally worked on for CPython, but which was rejected as not worth all the bother after a lot of back-&-forth.

bpo-38938 [1]

Footnotes

  1. Fix GitHub issues link (edited by @erlend-aasland)

@pochmann
Copy link
Contributor Author

pochmann commented Sep 5, 2023

Thanks. From now on I'll also copy the original to ensure apples-to-apples.

I've seen Dennis's efforts when I played with my own completely different alternatives. I might've given up on them when I saw the rejection of Dennis's, don't remember. Maybe there's hope for this one, as it's not completely different but just a small change. And I think it's not just faster but also nicer, at least the initialization and end parts. More so for the version with key:

Current:

    for order, it in enumerate(map(iter, iterables)):
        try:
            next = it.__next__
            value = next()
            h_append([key(value), order * direction, value, next])
        except StopIteration:
            pass

Proposal:

    for order, it in enumerate(map(iter, iterables)):
        for value in it:
            h_append([key(value), order * direction, value, it])
            break

I imagine someone not used to such for-break(-else) usage (i.e., most people) might still prefer the current version, at least for the merging part. Then again, getting __next__ and then __self__ is also unusual, as perhaps is the try-except around the whole inner while loop, with its apparent need for a comment saying where StopIteration might come from. I guess that's part of what Raymond meant by "messy and convoluted" (which, by the way, is what finally made me propose this, despite its unusualness :-).
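The snippet above only covers the key branch's initialization. Below is a sketch of how the whole key branch, merging part included, might look with the same pattern. It is an illustration rather than code from this thread: a standalone generator (the name merge_key_sketch is hypothetical) that handles only reverse=False and uses the public heapify, heappop and heapreplace instead of merge's local aliases.

```python
from heapq import heapify, heappop, heapreplace


def merge_key_sketch(*iterables, key):
    """Stable merge of sorted iterables by key, using for-break(-else)."""
    h = []
    h_append = h.append
    for order, it in enumerate(map(iter, iterables)):
        for value in it:
            # order breaks ties, so values and iterators are never compared
            h_append([key(value), order, value, it])
            break
    heapify(h)
    while len(h) > 1:
        while True:
            key_value, order, value, it = s = h[0]
            yield value
            for value in it:
                s[0] = key(value)
                s[2] = value
                heapreplace(h, s)      # restore heap condition
                break
            else:
                heappop(h)             # remove empty iterator
                break
    if h:
        # fast case when only a single iterator remains
        key_value, order, value, it = h[0]
        yield value
        yield from it
```

Because the order field breaks ties, equal keys come out in input order, as with heapq.merge.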

@tim-one
Copy link
Member

tim-one commented Sep 6, 2023

The PyPy confusion is all on me. PyPy usually uses the Python code that ships with CPython, but not always, and I was careless in not checking that first in this case (indeed, it uses CPython's heapq.py for everything except heapq.merge()).

Every way of interleaving a collection of iterators some of which may "end early" has some kind of "boy - that sure looks weird at first!" wart. The single weirdest thing to me here is the current code's:

            yield from next.__self__

I've never used that, and don't expect I ever will, but it's obvious from context what it must do, and I trust that Raymond got the inscrutable details right 😄.
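What next.__self__ does can be checked directly: the bound __next__ (a method-wrapper for built-in iterators, a bound method for Python classes) keeps a reference to the object it is bound to in its __self__ attribute, so the current code can recover the iterator from the saved __next__. A quick check with a list iterator:

```python
it = iter([1, 2, 3])
next_ = it.__next__           # what the current merge stores in its heap entries
print(next_.__self__ is it)   # True
print(next_())                # 1
print(list(next_.__self__))   # [2, 3]: the rest of the same iterator
```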

Dennis was aiming at much bigger speedups, and at least under PyPy he got them. At least under-appreciated and perhaps wholly un-appreciated here: building a heap with tuple entries adds a steep time penalty because of the "tuple" part. While the heap code only cares about __lt__() results, and many types implement __lt__ efficiently, every kind of tuple comparison begins by finding the longest all-equal prefix first (generally true of lexicographically ordered container-type comparisons).

So, e.g., if we're merging ints, the merge code never does i < j directly. The ints are inside tuples, so tuple.__lt__ will first ask if the ints are equal. If the answer is "no", it will do a second comparison to see whether "less than?" obtains.

One thing I learned from implementing sorting is that PyObject_RichCompareBool() is frickin' expensive 😉, and putting "the real value" inside tuples can double the number of such calls needed (well, even triple: one call to tuple.__lt__, another to int.__eq__, and a third to int.__lt__).
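The extra comparisons are easy to observe with a hypothetical instrumented class that counts its __eq__ and __lt__ calls. Comparing two containers whose first elements differ runs __eq__ (to find the first differing position) and then __lt__, while comparing the values directly runs only __lt__. Lists are included because merge's heap entries are actually lists; they use the same lexicographic rule as tuples.

```python
class Val:
    """Hypothetical wrapper that counts __eq__ and __lt__ calls."""
    eq_calls = lt_calls = 0

    def __init__(self, v):
        self.v = v

    def __eq__(self, other):
        Val.eq_calls += 1
        return self.v == other.v

    def __lt__(self, other):
        Val.lt_calls += 1
        return self.v < other.v


def count_calls(compare):
    """Run compare() and return (number of __eq__ calls, number of __lt__ calls)."""
    Val.eq_calls = Val.lt_calls = 0
    compare()
    return Val.eq_calls, Val.lt_calls


a, b = Val(1), Val(2)
print(count_calls(lambda: a < b))             # (0, 1)
print(count_calls(lambda: (a, 0) < (b, 1)))   # (1, 1)
print(count_calls(lambda: [a, 0] < [b, 1]))   # (1, 1)
```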

If I take the current CPython "no-key" merge code, and fiddle it just a little to put instances of this class on the heap instead:

class Item:
    __slots__ = 'val', 'next'

    def __init__(self, val, next):
        self.val = val
        self.next = next

    def __lt__(x, y):
        return x.val < y.val

then the heap code calls Item.__lt__ directly. The object overhead makes it slower under CPython, but it's significantly faster under PyPy (which has much lower overhead for using simple objects).
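A sketch of that fiddling, assuming it means swapping the list entries in the current no-key loop for Item instances (the modified code itself isn't shown in this thread, and the name merge_items is hypothetical). Since Item compares only val, ties between equal values from different inputs are not broken by input order. Unlike heapq.merge, this version is therefore not guaranteed to be stable, which is harmless for plain ints but matters for distinguishable equal values.

```python
from heapq import heapify, heappop, heapreplace


class Item:
    __slots__ = 'val', 'next'

    def __init__(self, val, next):
        self.val = val
        self.next = next

    def __lt__(x, y):
        # The heap compares the values directly: no list/tuple wrapper,
        # so no extra __eq__ call per comparison.
        return x.val < y.val


def merge_items(*iterables):
    """No-key merge with Item objects on the heap (hypothetical sketch)."""
    h = []
    for it in map(iter, iterables):
        next_ = it.__next__
        try:
            h.append(Item(next_(), next_))
        except StopIteration:
            pass
    heapify(h)
    while len(h) > 1:
        try:
            while True:
                s = h[0]
                yield s.val
                s.val = s.next()      # raises StopIteration when exhausted
                heapreplace(h, s)     # restore heap condition
        except StopIteration:
            heappop(h)                # remove empty iterator
    if h:
        # fast case when only a single iterator remains
        s = h[0]
        yield s.val
        yield from s.next.__self__
```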

This is part of why Dennis's code is so much quicker under PyPy too (it's also avoiding "triple comparison" overheads for embedding the real values of interest in tuples).

But short of all that, your proposed code here looks like simple changes for modest speed gain. +0.5 from me.

@tim-one
Member

tim-one commented Sep 7, 2023

Some food for thought. Merging has some properties that heaps can be specialized for. One was mentioned before: find a way to convince the code to call __lt__ directly on the key fields rather than indirectly via a tuple comparison.

Another: for reasons explained in code comments, for general use the Python heap code first moves heap "holes" all the way to leaf positions, and then puts the next value into that hole and bubbles it up. That's because for "random" entry orders, and overwhelmingly so in an in-place heap sort, "the next value" is most likely to end up closer to a leaf than to the root.

But a different common way is to put the next value into the hole at once, and move it down a level only if one of the child nodes is smaller. This has a much better best case (could be the next value belongs in the hole at once), the same worst case (about 2 * lg n compares), and worse expected case for "random" (etc) insertion orders.

But in the specific application of merging, we're replacing the smallest overall value with the next-smallest in the stream that winner came from. So it's probably close to the value we just yielded, and so probably belongs high in the heap. That makes the alternative way much more attractive. Indeed, the great attraction of tournament-style "loser trees" for merging is that replacement always takes exactly lg n compares. But the alternative heap method can get away with fewer at times.
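The trade-off between the two replacement strategies can be checked by counting __lt__ calls. This sketch assumes a hypothetical counting wrapper Counted and two pure-Python replacement functions. replace_bottom_up follows the strategy of heapq's Python source (move the hole to a leaf, then bubble the new item up), while replace_top_down stops sinking as soon as neither child is smaller. On a 15-element heap, a new value that belongs at the root (merge-like) costs 6 compares bottom-up but only 2 top-down; a new value that belongs at a leaf costs 4 bottom-up and 6 top-down.

```python
class Counted:
    """Hypothetical wrapper that counts every __lt__ call."""
    calls = 0
    __slots__ = ('v',)

    def __init__(self, v):
        self.v = v

    def __lt__(self, other):
        Counted.calls += 1
        return self.v < other.v


def replace_bottom_up(heap, item):
    """Replace heap[0]: move the hole down to a leaf, then bubble item up."""
    endpos = len(heap)
    pos, childpos = 0, 1
    while childpos < endpos:
        rightpos = childpos + 1
        if rightpos < endpos and not heap[childpos] < heap[rightpos]:
            childpos = rightpos
        heap[pos] = heap[childpos]
        pos, childpos = childpos, 2 * childpos + 1
    while pos > 0:                       # bubble item up from the leaf hole
        parentpos = (pos - 1) >> 1
        if item < heap[parentpos]:
            heap[pos] = heap[parentpos]
            pos = parentpos
        else:
            break
    heap[pos] = item


def replace_top_down(heap, item):
    """Replace heap[0]: sink item only while its smaller child is smaller."""
    endpos = len(heap)
    pos, childpos = 0, 1
    while childpos < endpos:
        rightpos = childpos + 1
        if rightpos < endpos and heap[rightpos] < heap[childpos]:
            childpos = rightpos
        if heap[childpos] < item:
            heap[pos] = heap[childpos]
            pos, childpos = childpos, 2 * childpos + 1
        else:
            break
    heap[pos] = item


def compares_for(replace, n, new_value):
    """Replace the root of the heap [0, 1, ..., n-1] and count the compares."""
    heap = [Counted(i) for i in range(n)]   # a sorted list is a valid min-heap
    Counted.calls = 0
    replace(heap, Counted(new_value))
    calls = Counted.calls
    # Sanity check (on raw values, so not counted): no child beats its parent.
    assert all(not heap[i].v < heap[(i - 1) >> 1].v for i in range(1, n))
    return calls


for new_value in (0.5, 100):
    print(new_value,
          compares_for(replace_bottom_up, 15, new_value),
          compares_for(replace_top_down, 15, new_value))
# 0.5 6 2
# 100 4 6
```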

Anyway, how practical is this? Very, but only under PyPy. The only thing that needs to change is the implementation of heapreplace() used. I'll include a first stab at that below. It's significantly faster than the status quo under PyPy, and if I boost the number of streams merged in your benchmark code from 3 to 16 (the more typical case, I think), it's well over twice as fast.

Under CPython it loses. That's because CPython is using a C-coded heapreplace(), and I haven't coded the alternative in C (it's using the Python implementation below). PyPy is using Python code for everything in the "before" and "after" runs.

Here's the code. The only change to merge is replacing heapreplace with heapaltreplace. So this is orthogonal to what Stefan is pursuing.

def heapaltreplace(heap, item):
    endpos = len(heap)
    pos = 0
    # Don't do tuple compares! Compare the primary and secondary keys directly.
    # And, as usual, stick to doing only __lt__ compares.
    key1, key2 = item[:2]
    childpos = 2*pos + 1    # leftmost child position
    while childpos < endpos:
        # Set childpos to index of smaller child, and ck1 & ck2 to
        # that child's keys.
        ck1, ck2 = heap[childpos][:2]
        rightpos = childpos + 1
        if rightpos < endpos:
            rk1, rk2 = heap[rightpos][:2]
            if rk1 < ck1 or (rk1 == ck1 and rk2 < ck2):
                childpos = rightpos
                ck1, ck2 = rk1, rk2
        if ck1 < key1 or (ck1 == key1 and ck2 < key2):
            heap[pos] = heap[childpos]
            pos = childpos
            childpos = 2*pos + 1
        else:
            break
    # The slot at pos is empty now.
    heap[pos] = item

@pochmann
Contributor Author

pochmann commented Sep 7, 2023

Those are some good ideas, and I've tried some more ideas myself and I'm quite interested to pursue them all later, but here I'll stick with just the for-break-else one, hoping it's small enough to actually get adopted.

@iritkatriel iritkatriel added the stdlib Python modules in the Lib dir label Nov 27, 2023
Projects
None yet
Development

No branches or pull requests

4 participants