Skip to content

gh-113947: Speed up Counter.__eq__ #113948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 16 commits into from

Conversation

keithasaurus
Copy link
Contributor

@keithasaurus keithasaurus commented Jan 11, 2024

What I did

  • rewrite Counter.eq to be more efficient
  • built with ./configure --enable-optimizations
  • ran a benchmark which showed ~1.45x efficiency gain.

Benchmark

(running on an m1 Mac -- code is pasted below)

main branch
(format is <iteration index>: <duration in seconds>)

0: 0.32239580154418945
1: 0.4231119155883789
2: 2.254210948944092
3: 2.2577903270721436
4: 0.34040093421936035
5: 1.4848048686981201
6: 0.8643209934234619
total duration: 7.947035789489746

PR feature branch

0: 0.3302881717681885
1: 0.37781810760498047
2: 1.4705991744995117
3: 1.4680209159851074
4: 0.24953103065490723
5: 1.0074560642242432
6: 0.5758671760559082
total duration: 5.479580640792847

Benchmark code

from collections import Counter
from typing import Any
from time import time


def benchmark(iterations: int, a: Counter[Any], b: Counter[Any]) -> float:
    start = time()
    for i in range(iterations):
        a == b
    return time() - start


total_duration = 0.0
for i, (a, b) in enumerate([
    (Counter(), Counter()),
    (Counter(), Counter("abcdefghijklmnop")),
    (Counter("abcdefghijklmnop"), Counter("abcdefghijklmnop")),
    (Counter("abcdefghijklmnop"), Counter("abcdefghijklmnop")),
    (Counter("aaaaaaaaaaaaaaaa"), Counter("abcdefghijklmnop")),
    (Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
    (Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10]), Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
]):
    result = benchmark(1000000, a, b)
    print(f"{i}:", result)
    total_duration += result

print("total duration:", total_duration)

@ghost
Copy link

ghost commented Jan 11, 2024

All commit authors signed the Contributor License Agreement.
CLA signed

@bedevere-app
Copy link

bedevere-app bot commented Jan 11, 2024

Most changes to Python require a NEWS entry. Add one using the blurb_it web app or the blurb command-line tool.

If this change has little impact on Python users, wait for a maintainer to apply the skip news label instead.

@bedevere-app
Copy link

bedevere-app bot commented Jan 11, 2024

Most changes to Python require a NEWS entry. Add one using the blurb_it web app or the blurb command-line tool.

If this change has little impact on Python users, wait for a maintainer to apply the skip news label instead.

@Eclips4 Eclips4 changed the title Speed up Counter.__eq__ gh-113947: Speed up Counter.__eq__ Jan 11, 2024
@serhiy-storchaka serhiy-storchaka self-requested a review January 11, 2024 21:21
Copy link
Member

@serhiy-storchaka serhiy-storchaka left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. The main effect is perhaps from getting rid of a generator.

Please add a simple NEWS entry.

@AA-Turner
Copy link
Member

To @terryjreedy's point, I benchmarked a variety of permutations.

Function Duration
seq_loop 859.5 ms
nested_loop 970.7 ms
nested_loop_keys 1223.3 ms
seq_loop_keys 1232.2 ms
all_map 1685.2 ms
original 1786.8 ms
all_items 1819.1 ms
seq_any 1833.0 ms
all_explode 1879.0 ms
nested_any 1952.3 ms
seq_any_keys 2057.6 ms
nested_any_keys 2125.4 ms
seq_any_listcomp 2339.0 ms
nested_any_listcomp 2384.2 ms
all_listcomp 3632.4 ms

This ran 10,000 iterations each of a modified version of @keithasaurus's script, including a very large Counter.

Benchmarking
import string
from collections import Counter
from operator import itemgetter
from time import perf_counter_ns

ALPHA_100 = string.ascii_letters * 100
A_100 = "a" * len(ALPHA_100)
COUNTERS = (
    (Counter(), Counter()),
    (Counter(), Counter("abcdefghijklmnop")),
    (Counter("abcdefghijklmnop"), Counter("abcdefghijklmnop")),
    (Counter("abcdefghijklmnopppppp"), Counter("abcdefghijklmnop")),
    (Counter("abcdefghijklmnop"), Counter("abcdefghijklmnopppppp")),
    (Counter("aaaaaaaaaaaaaaaa"), Counter("abcdefghijklmnop")),
    (Counter("abcdefghijklmnop"), Counter("aaaaaaaaaaaaaaaa")),
    (Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
    (Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10]), Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])),
    (Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), Counter([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10])),
    (Counter(ALPHA_100), Counter(A_100)),
    (Counter(A_100), Counter(ALPHA_100)),
)
EQS = [a == b for (a, b) in COUNTERS]


def benchmark(iterations, a, b):
    start = perf_counter_ns()
    for i in range(iterations):
        a == b
    return perf_counter_ns() - start


def run_bench(iterations):
    total_duration = 0
    for (a, b) in COUNTERS:
        total_duration += benchmark(iterations, a, b)

    outcome = "" if [a == b for (a, b) in COUNTERS] == EQS else "(incorrect!)"
    print(f"  total duration: {total_duration / 10**6:0>5.1f} ms {outcome}")
    return total_duration


def eq_original(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    return all(self[e] == other[e] for c in (self, other) for e in c)


def eq_all_listcomp(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    return all([self[e] == other[e] for c in (self, other) for e in c])


def eq_all_explode(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    return all(self[k] == other[k] for k in self) and all(other[k] == self[k] for k in other)


def eq_all_items(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    return all(v == other[k] for k, v in self.items()) and all(v == self[k] for k, v in other.items())


def eq_all_map(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    return all(map((lambda t: t[1] == other[t[0]]), self.items())) and all(map((lambda t: t[1] == self[t[0]]), other.items()))


def eq_nested_loop(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    for obj1, obj2 in ((self, other), (other, self)):
        for k, v in obj1.items():
            if v != obj2[k]:
                return False
    return True


def eq_seq_loop(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    for k, v in self.items():
        if v != other[k]:
            return False
    for k, v in other.items():
        if v != self[k]:
            return False
    return True


def eq_nested_any(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    for obj1, obj2 in ((self, other), (other, self)):
        if any(v != obj2[k] for k, v in obj1.items()):
            return False
    return True


def eq_seq_any(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    if any(v != other[k] for k, v in self.items()):
        return False
    if any(v != self[k] for k, v in other.items()):
        return False
    return True


def eq_nested_loop_keys(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    if self.keys() != other.keys():
        return False
    for obj1, obj2 in ((self, other), (other, self)):
        for k, v in obj1.items():
            if v != obj2[k]:
                return False
    return True


def eq_seq_loop_keys(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    if self.keys() != other.keys():
        return False
    for k, v in self.items():
        if v != other[k]:
            return False
    for k, v in other.items():
        if v != self[k]:
            return False
    return True


def eq_nested_any_keys(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    if self.keys() != other.keys():
        return False
    for obj1, obj2 in ((self, other), (other, self)):
        if any(v != obj2[k] for k, v in obj1.items()):
            return False
    return True


def eq_seq_any_keys(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    if self.keys() != other.keys():
        return False
    if any(v != other[k] for k, v in self.items()):
        return False
    if any(v != self[k] for k, v in other.items()):
        return False
    return True


def eq_nested_any_listcomp(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    for obj1, obj2 in ((self, other), (other, self)):
        if any([v != obj2[k] for k, v in obj1.items()]):
            return False
    return True


def eq_seq_any_listcomp(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if any([v != other[k] for k, v in self.items()]):
        return False
    if any([v != self[k] for k, v in other.items()]):
        return False
    return True


results = {}
func = name = None
for func in (func for name, func in globals().items() if name.startswith('eq_')):
    name = func.__name__
    print(name)
    Counter.__eq__ = func
    results[name[3:]] = run_bench(iterations=10**5)

longest_name = len(max(results, key=len))
print("\nSorted timings\n" + "=" * (longest_name+10))
for f, d in sorted(results.items(), key=itemgetter(1)):
    print(f"{f: >{longest_name}}: {d / 10**6:0>5.1f} ms")

The second fastest function is the one suggested in this PR. The fastest function improves on that time by 11-12% (for a total 52% speedup) by unrolling the loop:

def __eq__(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    for k, v in self.items():
        if v != other[k]:
            return False
    for k, v in other.items():
        if v != self[k]:
            return False
    return True

I think we could also consider changing Counter.__ge__ and Counter.__le__ here, as they both can be similarly sped up.

A

@keithasaurus
Copy link
Contributor Author

keithasaurus commented Jan 12, 2024

@AA-Turner Yes, I actually had tried that version, but I wasn't sure two loops would be as welcome from a maintainability point of view.

It may be worth noting there is a potential further "optimization" I couldn't quite find.

def __eq__(self, other):
    'True if all counts agree. Missing counts are treated as zero.'
    if not isinstance(other, Counter):
        return NotImplemented
    for k, v in self.items():
        if v != other[k]:
            return False
    # at this point we've already covered the keys common to both `self` and `other`, 
    # so we really only need to check that any keys that aren't in `self` are equal to zero
    for k, v in other.items():
        if k not in self and v != 0:
            return False
    return True

However, it didn't show an improvement in performance -- every iteration is still doing a dict lookup and an equality check. I'm not sure if there's a way to improve it.

I'll update to the two loops.

@bedevere-app
Copy link

bedevere-app bot commented Jan 12, 2024

Most changes to Python require a NEWS entry. Add one using the blurb_it web app or the blurb command-line tool.

If this change has little impact on Python users, wait for a maintainer to apply the skip news label instead.

@bedevere-app
Copy link

bedevere-app bot commented Jan 12, 2024

Most changes to Python require a NEWS entry. Add one using the blurb_it web app or the blurb command-line tool.

If this change has little impact on Python users, wait for a maintainer to apply the skip news label instead.

@keithasaurus
Copy link
Contributor Author

keithasaurus commented Jan 12, 2024

I updated to the two for loops. I've also updated the code to do use dict's eq check first, which significantly speeds up the equal case -- this is because the dict equality is done in c-compiled code (right?). This slightly slows the non-matching case but further increases the efficiency according to this benchmark by ~2.3x.

Two for loops

0: 0.22596406936645508
1: 0.27464795112609863
2: 1.3507919311523438
3: 1.3493068218231201
4: 0.1584930419921875
5: 0.912395715713501
6: 0.4830331802368164
total duration: 4.7546327114105225

two for loops + super() equality check

0: 0.14082098007202148
1: 0.33873701095581055
2: 0.26291418075561523
3: 0.26647114753723145
4: 0.21964073181152344
5: 0.19705605506896973
6: 0.6212449073791504
total duration: 2.0468850135803223

So, according to this benchmark, we'd be at about a 4x speedup overall.

@AA-Turner
Copy link
Member

Another potential:

return {k: v for k, v in self.items() if v} == {k: v for k, v in other.items() if v}

I'm on mobile so can't test, but this may benefit due to the comprehension inlining in Python 3.12 and later.

A

@keithasaurus
Copy link
Contributor Author

Another potential:

return {k: v for k, v in self.items() if v} == {k: v for k, v in other.items() if v}

I'm on mobile so can't test, but this may benefit due to the comprehension inlining in Python 3.12 and later.

A

I would be surprised if the inlining of comprehensions was fast enough to overcome a lack of short-circuiting, but definitely open to it if it works.

@keithasaurus
Copy link
Contributor Author

@AA-Turner For the dict comprehension approach I got:

0: 0.31111598014831543
1: 0.8031868934631348
2: 1.4441249370574951
3: 1.4424488544464111
4: 0.8289532661437988
5: 0.9804871082305908
6: 0.9872369766235352
total duration: 6.797554016113281

I expect the slowness is in part due to the need to instantiate new dicts, as well as no short-circuiting in the False case.

@AA-Turner
Copy link
Member

I think this is good to merge, will hold off for a couple of days for any concerns to be raised.

A

@AA-Turner
Copy link
Member

@keithasaurus just to note there's no need to frequently merge with HEAD unless there are conflicting changes (which is fairly rare) -- it takes up additional CI resource for extra commits etc, which we should attempt to avoid.

A

@HarryLHW
Copy link
Contributor

HarryLHW commented Jan 16, 2024

I just noticed that if you use if v and k not in self:, you will make Counter() == Counter({'a': []}) True (it is not recommended to use non-int values, but it is allowed), which was not the case before your change. Maybe you could try if v != 0 and k not in self:

@serhiy-storchaka
Copy link
Member

serhiy-storchaka commented Jan 16, 2024

@keithasaurus, please add this case in tests.

# so now we can just check that any keys that
# aren't in self are equal to zero
for k, v in other.items():
if v and k not in self:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if v and k not in self:
if v != 0 and k not in self:

@keithasaurus
Copy link
Contributor Author

@HarryLHW good catch. I did the following:

  1. wrote the test
  2. ran tests (tests failed)
  3. made the suggested change in the __eq__ method
  4. ran tests (tests passed)

@serhiy-storchaka serhiy-storchaka self-requested a review January 17, 2024 11:26
@serhiy-storchaka serhiy-storchaka dismissed their stale review January 17, 2024 11:26

Too many changes.

@@ -788,7 +788,21 @@ def __eq__(self, other):
'True if all counts agree. Missing counts are treated as zero.'
if not isinstance(other, Counter):
return NotImplemented
return all(self[e] == other[e] for c in (self, other) for e in c)

if super().__eq__(other):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__eq__() can return NotImplemented, and NotImplemented should not be used in boolean context (it emits a warning).

If you say that dict.__eq__() never returns NotImplemented for arguments that are instances of dict subclasses, than use dict.__eq__().

I wonder what effect this optimization has on comparison of two large Counters with equal size but slightly different keys?

c1 = Counter(large_dict)
c2 = Counter(large_dict)
c1[new_k1] = 0
c2[new_k2] = v  # may be 0 or not 0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you say that dict.eq() never returns NotImplemented for arguments that are instances of dict subclasses, than use dict.eq().

Because of the return NotImplemented for non-Counters, I believe we only need to be concerned with the case that dict.__eq__ works for two Counter instances. I'm not able to think of a situation in which this would fail.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like:

class A(dict):
    def __eq__(s, o): return NotImplemented

class B(Counter, A): pass

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a little test script. Please feel free to probe further if you have other ideas. Both main and this branch return the same results.

from collections import Counter


class A(dict):
    def __eq__(s, o): return NotImplemented

class B(Counter, A):
    pass

class C(A, Counter):
    pass


a = A()
b = B()
c = C()

print(a == b)
print(b == a)
print(c == b)
print(b == c)
print(a == c)
print(c == a)

print({} == a)
print(a == {})

print({} == b)
print(b == {})

print({} == c)
print(c == {})

result

False
False
True
True
False
False
True
True
True
True
True
True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> print(c == b)
/home/serhiy/py/cpython/Lib/collections/__init__.py:792: DeprecationWarning: NotImplemented should not be used in a boolean context
  if super().__eq__(other):
True
>>> print(b == c)
/home/serhiy/py/cpython/Lib/collections/__init__.py:792: DeprecationWarning: NotImplemented should not be used in a boolean context
  if super().__eq__(other):
True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The simple solution would be just to check is True rather than the implicit bool. Does that meaningfully alter the performance gains here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AA-Turner No, it's essentially unchanged, so I added the explicit is True. Results from the benchmark here after the change are:

0: 0.09299182891845703
1: 0.20043516159057617
2: 0.08538293838500977
3: 0.7197530269622803
4: 0.719200611114502
5: 0.017891883850097656
total duration: 1.8356554508209229

This is presumably because is True is only done once per call -- it's not part of any iteration.

@serhiy-storchaka
Copy link
Member

Please benchmark also with large Counters (thousands or tens of thousands of items, or more). Do it also with tuple keys (they have more expensive hash and comparison).

@keithasaurus
Copy link
Contributor Author

keithasaurus commented Jan 18, 2024

Please benchmark also with large Counters (thousands or tens of thousands of items, or more). Do it also with tuple keys (they have more expensive hash and comparison).

The results of this are in line with what we've seen in the above benchmarks (code below).

main branch

0: 1.058445930480957
1: 0.26082491874694824
2: 1.021928071975708
3: 1.0234870910644531
4: 1.0238280296325684
5: 0.24341487884521484
total duration: 4.63192892074585

this branch

0: 0.1014411449432373
1: 0.20003199577331543
2: 0.08758807182312012
3: 0.7278151512145996
4: 0.730457067489624
5: 0.0179901123046875
total duration: 1.865323543548584

The code:

from collections import Counter
from copy import copy
from typing import Any
from time import time

import urllib.request
with urllib.request.urlopen('https://www.usconstitution.net/const.txt') as f:
    text = f.read().decode('utf-8')

constitution_counter = Counter(text.split())
constitution_counter_1 = copy(constitution_counter)
constitution_counter_1["word_not_in_counter"] = 1
constitution_zero_equal = copy(constitution_counter)
constitution_zero_equal["word_not_in_counter"] = 0

tuple_counter = Counter()
split_text = text.split()[:2000]
for i, word in enumerate(split_text):
    try:
        tuple_counter[(word, split_text[i+1])] += 1
    except IndexError:
        tuple_counter[(word,)] += 1

tuple_counter_1 = copy(tuple_counter)
tuple_counter_1[(split_text[998], split_text[999])] += 1


class CounterSubClass(Counter):
    pass


subclass_counter = CounterSubClass()
non_subclass_counter = Counter()
for word in split_text[:1000]:
    subclass_counter[word] += 1
    non_subclass_counter[word] += 1


def benchmark(iterations: int, a: Counter[Any], b: Counter[Any], expected_result: bool) -> float:
    assert (a == b) is expected_result

    start = time()
    for i in range(iterations):
        a == b
    return time() - start


total_duration = 0.0
for i, (a, b, expected_result) in enumerate([
    (tuple_counter, tuple_counter, True),
    # unequal at the end
    (tuple_counter, tuple_counter_1, False),
    # long equal
    (constitution_counter, constitution_counter, True),
    # long equal but will super.__eq__ will return False
    (constitution_counter, constitution_zero_equal, True),
    # long unequal (at the very end)
    (constitution_counter, constitution_counter_1, False),
    # equivalence between subclasses using __dict__.eq
    (subclass_counter, non_subclass_counter, True)
]):
    result = benchmark(5000, a, b, expected_result)
    print(f"{i}:", result)
    total_duration += result

print("total duration:", total_duration)

@rhettinger
Copy link
Contributor

rhettinger commented Jan 19, 2024

FWIW, I greatly prefer the current code which is self-evidently correct, easy to maintain, and easy to explain to others.

Also, the super call can't be predicted in advance where it will go or what is will do because it is the MRO of a subclass that determines that actual search order.

If the small speed-up comes from removing the generator expression, I am dubious becauae that feels like just chasing the current state of CPython where the performance is in flux and the situation may reverse in the future as other optimizations kick in. Over the long-term, the cleanest, shortest, clearest, and most idiomatic code tends to win.

I recommend that we not do this. We all have limits on how unclean we're willing the make code to save a few clock cycles. In this case, I'm not willing to forgo the current succinct clear code. This is doubly true because in practice this method is hardly ever used, and it is unlikely that any real program would ever benefit.

@keithasaurus
Copy link
Contributor Author

keithasaurus commented Jan 19, 2024

@rhettinger I agree with a lot of the points. The first code change I proposed was simpler because I thought it was more maintainable:

def __eq__(self, other):
    if not isinstance(other, Counter):
        return NotImplemented

    for obj1, obj2 in ((self, other), (other, self)):
        for k, v in obj1.items():
             if v != obj2[k]:
                  return False
    return True

To my understanding this is faster because of primarily because of removing the generator, as you mentioned, but also because it does fewer dict lookups -- iterating over .items() instead of the keys. The only real point I differ on is that I think this is easier to read than multi-level generator expressions. I benchmarked this at ~1.45x performance gain.

I just pushed up a version of the code that is the simpler original PR code, but also includes the dict.__eq__ check explicitly -- instead of using super().__eq__() -- to resolve the MRO concern. This is ~2.7x speedup for the latest benchmark code above -- the main difference is that the exact equality check with dict.__eq__ is up to 12x as fast as the current code. Does this strike you as a better balance of maintainability / performance?

This is doubly true because in practice this method is hardly ever used, and it is unlikely that any real program would ever benefit.

__eq__ is also used as part of __gt__ and __lt__. Do you think those are similarly rare, because I was considering refactoring __le__ and __ge__, which are also used by the respective __lt__ and __gt__ methods.

@rhettinger
Copy link
Contributor

rhettinger commented Jan 19, 2024

My preference is to leave the code in its current state which is clean, simple, explainable, and reasonably performant.

The only real point I differ on is that I think this is easier to read than multi-level generator expressions.

I don't see any issue with a nested-for but it really bugs you, it is easy to replace it with itertools.chain:

- return all(self[e] == other[e] for c in (self, other) for e in c)
+ return all(self[e] == other[e] for e in chain(self, other))

@rhettinger
Copy link
Contributor

Thanks for the suggestion, but I will decline. This isn't a speed-critical method that warrants heroic efforts.

I greatly prefer the current code which reads like a definition of what the __eq__ test should do. It is easy to maintain, to explain to others, and check parallels to the other rich comparisons.

The current code only depends on __iter__*() and __getitem__() and will respect any overrides to those methods made a by a subclasser. I don't think that a subclasser should also have to change items(). The dict.__eq__ fast path is only faster if the dicts are identical; otherwise, it slows down the not-equal case and the equal-but-for-zero-count cases.

@rhettinger rhettinger closed this Feb 15, 2024
@pochmann3
Copy link
Contributor

pochmann3 commented Feb 29, 2024

If the small speed-up comes from removing the generator expression, I am dubious becauae that feels like just chasing the current state of CPython where the performance is in flux and the situation may reverse in the future as other optimizations kick in. Over the long-term, the cleanest, shortest, clearest, and most idiomatic code tends to win.

I was curious how much of the speed-up from "removing the generator expression" comes from no longer having the generator pointlessly send True values to all instead of filtering them out. Looks like that can be most of it, and we can avoid it in the generator, and I highly doubt that CPython will ever do that optimization on its own.

The three candidates:

def generator(xs):
    return all(x for x in xs)

def loop(xs):
    for x in xs:
        if not x:
            return False
    return True

def generator_no_pointless_Trues(xs):
    return all(False for x in xs if not x)

Benchmark with xs = [True] * 1000, the optimized generator gets very close to the loop:

  9.6 ± 0.0 μs  loop
 10.0 ± 0.0 μs  generator_no_pointless_Trues
 30.3 ± 0.1 μs  generator

Python: 3.12.0 (main, Oct  7 2023, 10:42:35) [GCC 13.2.1 20230801]

That said, it's less extreme if there are fewer values or if there's an early false value. And of course readability suffers, and I prefer the non-optimized generator here .

Benchmark script

Attempt This Online!

def generator(xs):
    return all(x for x in xs)

def loop(xs):
    for x in xs:
        if not x:
            return False
    return True

def generator_no_pointless_Trues(xs):
    return all(False for x in xs if not x)

funcs = [generator, loop, generator_no_pointless_Trues]

from timeit import timeit
from statistics import mean, stdev
import sys
import random

# Correctness
xss = [[]]
for _ in range(5):
    for xs in xss:
        assert len({f(xs) for f in funcs}) == 1
    xss = [xs + [b] for xs in xss for b in (False, True)]

# Speed
xs = [True] * 1000
times = {f: [] for f in funcs}
def stats(f):
    ts = [t * 1e6 for t in sorted(times[f])[:10]]
    return f' {mean(ts):4.1f} ± {stdev(ts):3.1f} μs '
for _ in range(100):
    random.shuffle(funcs)
    for f in funcs:
        t = timeit(lambda: f(xs), number=100) / 100
        times[f].append(t)
for f in sorted(funcs, key=stats):
    print(stats(f), f.__name__)

print('\nPython:', sys.version)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants