diff --git a/crates/vm/src/builtins/list.rs b/crates/vm/src/builtins/list.rs index 84825de7d..9c961d292 100644 --- a/crates/vm/src/builtins/list.rs +++ b/crates/vm/src/builtins/list.rs @@ -522,12 +522,17 @@ fn do_sort( key_func: Option, reverse: bool, ) -> PyResult<()> { - let op = if reverse { - PyComparisonOp::Lt - } else { - PyComparisonOp::Gt + // CPython uses __lt__ for all comparisons in sort. + // try_sort_by_gt expects is_gt(a, b) = true when a should come AFTER b. + let cmp = |a: &PyObjectRef, b: &PyObjectRef| { + if reverse { + // Descending: a comes after b when a < b + a.rich_compare_bool(b, PyComparisonOp::Lt, vm) + } else { + // Ascending: a comes after b when b < a + b.rich_compare_bool(a, PyComparisonOp::Lt, vm) + } }; - let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm); if let Some(ref key_func) = key_func { let mut items = values diff --git a/extra_tests/snippets/builtin_list.py b/extra_tests/snippets/builtin_list.py index 3e6bb8fc9..711bf6acc 100644 --- a/extra_tests/snippets/builtin_list.py +++ b/extra_tests/snippets/builtin_list.py @@ -270,6 +270,47 @@ class C: lst.sort(key=C) assert lst == [1, 2, 3, 4, 5] +# Test that sorted() uses __lt__ (not __gt__) for comparisons. +# Track which comparison method is actually called during sort. +class TrackComparison: + lt_calls = 0 + gt_calls = 0 + + def __init__(self, value): + self.value = value + + def __lt__(self, other): + TrackComparison.lt_calls += 1 + return self.value < other.value + + def __gt__(self, other): + TrackComparison.gt_calls += 1 + return self.value > other.value + +# Reset and test sorted() +TrackComparison.lt_calls = 0 +TrackComparison.gt_calls = 0 +items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)] +sorted(items) +assert TrackComparison.lt_calls > 0, "sorted() should call __lt__" +assert TrackComparison.gt_calls == 0, f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times" + +# Reset and test list.sort() +TrackComparison.lt_calls = 0 +TrackComparison.gt_calls = 0 +items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)] +items.sort() +assert TrackComparison.lt_calls > 0, "list.sort() should call __lt__" +assert TrackComparison.gt_calls == 0, f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times" + +# Reset and test sorted(reverse=True) - should still use __lt__, not __gt__ +TrackComparison.lt_calls = 0 +TrackComparison.gt_calls = 0 +items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)] +sorted(items, reverse=True) +assert TrackComparison.lt_calls > 0, "sorted(reverse=True) should call __lt__" +assert TrackComparison.gt_calls == 0, f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times" + lst = [5, 1, 2, 3, 4]