mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Fix sorted() to use __lt__ instead of __gt__ (#6887)
* test * Fix sorted() to use __lt__ instead of __gt__ CPython's sort uses __lt__ for comparisons, but RustPython was using __gt__. This caused issues when only __lt__ was overridden on a subclass (e.g., NamedTuple with custom __lt__), as it would fall back to the parent class's comparison instead of using the overridden method.
This commit is contained in:
@@ -522,12 +522,17 @@ fn do_sort(
|
||||
key_func: Option<PyObjectRef>,
|
||||
reverse: bool,
|
||||
) -> PyResult<()> {
|
||||
let op = if reverse {
|
||||
PyComparisonOp::Lt
|
||||
} else {
|
||||
PyComparisonOp::Gt
|
||||
// CPython uses __lt__ for all comparisons in sort.
|
||||
// try_sort_by_gt expects is_gt(a, b) = true when a should come AFTER b.
|
||||
let cmp = |a: &PyObjectRef, b: &PyObjectRef| {
|
||||
if reverse {
|
||||
// Descending: a comes after b when a < b
|
||||
a.rich_compare_bool(b, PyComparisonOp::Lt, vm)
|
||||
} else {
|
||||
// Ascending: a comes after b when b < a
|
||||
b.rich_compare_bool(a, PyComparisonOp::Lt, vm)
|
||||
}
|
||||
};
|
||||
let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm);
|
||||
|
||||
if let Some(ref key_func) = key_func {
|
||||
let mut items = values
|
||||
|
||||
@@ -270,6 +270,47 @@ class C:
|
||||
lst.sort(key=C)
|
||||
assert lst == [1, 2, 3, 4, 5]
|
||||
|
||||
# Test that sorted() uses __lt__ (not __gt__) for comparisons.
|
||||
# Track which comparison method is actually called during sort.
|
||||
class TrackComparison:
|
||||
lt_calls = 0
|
||||
gt_calls = 0
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __lt__(self, other):
|
||||
TrackComparison.lt_calls += 1
|
||||
return self.value < other.value
|
||||
|
||||
def __gt__(self, other):
|
||||
TrackComparison.gt_calls += 1
|
||||
return self.value > other.value
|
||||
|
||||
# Reset and test sorted()
|
||||
TrackComparison.lt_calls = 0
|
||||
TrackComparison.gt_calls = 0
|
||||
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
|
||||
sorted(items)
|
||||
assert TrackComparison.lt_calls > 0, "sorted() should call __lt__"
|
||||
assert TrackComparison.gt_calls == 0, f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
|
||||
|
||||
# Reset and test list.sort()
|
||||
TrackComparison.lt_calls = 0
|
||||
TrackComparison.gt_calls = 0
|
||||
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
|
||||
items.sort()
|
||||
assert TrackComparison.lt_calls > 0, "list.sort() should call __lt__"
|
||||
assert TrackComparison.gt_calls == 0, f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
|
||||
|
||||
# Reset and test sorted(reverse=True) - should still use __lt__, not __gt__
|
||||
TrackComparison.lt_calls = 0
|
||||
TrackComparison.gt_calls = 0
|
||||
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
|
||||
sorted(items, reverse=True)
|
||||
assert TrackComparison.lt_calls > 0, "sorted(reverse=True) should call __lt__"
|
||||
assert TrackComparison.gt_calls == 0, f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times"
|
||||
|
||||
lst = [5, 1, 2, 3, 4]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user