diff --git a/Cargo.lock b/Cargo.lock index c9c78ffe6..881852168 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1640,6 +1640,7 @@ dependencies = [ "socket2", "statrs", "thread_local", + "timsort", "uname", "unic-char-property", "unic-normal", @@ -1984,6 +1985,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "timsort" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a035368bf2997adda2f78c1e2683e2b8e8f584e349e8cb7f8adf294b0a4ad70" + [[package]] name = "tinyvec" version = "0.3.4" diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py index 6049b1f80..59b5c7227 100644 --- a/Lib/test/test_builtin.py +++ b/Lib/test/test_builtin.py @@ -873,7 +873,6 @@ class BuiltinTest(unittest.TestCase): m2 = map(map_char, "Is this the real life?") self.check_iter_pickle(m1, list(m2), proto) - @unittest.skip("TODO: RUSTPYTHON") def test_max(self): self.assertEqual(max('123123'), '3') self.assertEqual(max(1, 2, 3), 3) diff --git a/Lib/test/test_sort.py b/Lib/test/test_sort.py new file mode 100644 index 000000000..312d8a635 --- /dev/null +++ b/Lib/test/test_sort.py @@ -0,0 +1,390 @@ +from test import support +import random +import unittest +from functools import cmp_to_key + +verbose = support.verbose +nerrors = 0 + + +def check(tag, expected, raw, compare=None): + global nerrors + + if verbose: + print(" checking", tag) + + orig = raw[:] # save input in case of error + if compare: + raw.sort(key=cmp_to_key(compare)) + else: + raw.sort() + + if len(expected) != len(raw): + print("error in", tag) + print("length mismatch;", len(expected), len(raw)) + print(expected) + print(orig) + print(raw) + nerrors += 1 + return + + for i, good in enumerate(expected): + maybe = raw[i] + if good is not maybe: + print("error in", tag) + print("out of order at index", i, good, maybe) + print(expected) + print(orig) + print(raw) + nerrors += 1 + return + +class TestBase(unittest.TestCase): + def testStressfully(self): + # Try a variety of sizes at and around powers of 2, and at powers of 10. + sizes = [0] + for power in range(1, 10): + n = 2 ** power + sizes.extend(range(n-1, n+2)) + sizes.extend([10, 100, 1000]) + + class Complains(object): + maybe_complain = True + + def __init__(self, i): + self.i = i + + def __lt__(self, other): + if Complains.maybe_complain and random.random() < 0.001: + if verbose: + print(" complaining at", self, other) + raise RuntimeError + return self.i < other.i + + def __repr__(self): + return "Complains(%d)" % self.i + + class Stable(object): + def __init__(self, key, i): + self.key = key + self.index = i + + def __lt__(self, other): + return self.key < other.key + + def __repr__(self): + return "Stable(%d, %d)" % (self.key, self.index) + + for n in sizes: + x = list(range(n)) + if verbose: + print("Testing size", n) + + s = x[:] + check("identity", x, s) + + s = x[:] + s.reverse() + check("reversed", x, s) + + s = x[:] + random.shuffle(s) + check("random permutation", x, s) + + y = x[:] + y.reverse() + s = x[:] + check("reversed via function", y, s, lambda a, b: (b>a)-(b= 2: + def bad_key(x): + raise RuntimeError + s = x[:] + self.assertRaises(RuntimeError, s.sort, key=bad_key) + + x = [Complains(i) for i in x] + s = x[:] + random.shuffle(s) + Complains.maybe_complain = True + it_complained = False + try: + s.sort() + except RuntimeError: + it_complained = True + if it_complained: + Complains.maybe_complain = False + check("exception during sort left some permutation", x, s) + + s = [Stable(random.randrange(10), i) for i in range(n)] + augmented = [(e, e.index) for e in s] + augmented.sort() # forced stable because ties broken by index + x = [e for e, i in augmented] # a stable sort of s + check("stability", x, s) + +#============================================================================== + +class TestBugs(unittest.TestCase): + + @unittest.skip("TODO: RUSTPYTHON; figure out how to detect sort mutation that doesn't change list length") + def test_bug453523(self): + # bug 453523 -- list.sort() crasher. + # If this fails, the most likely outcome is a core dump. + # Mutations during a list sort should raise a ValueError. + + class C: + def __lt__(self, other): + if L and random.random() < 0.75: + L.pop() + else: + L.append(3) + return random.random() < 0.5 + + L = [C() for i in range(50)] + self.assertRaises(ValueError, L.sort) + + @unittest.skip("TODO: RUSTPYTHON; figure out how to detect sort mutation that doesn't change list length") + def test_undetected_mutation(self): + # Python 2.4a1 did not always detect mutation + memorywaster = [] + for i in range(20): + def mutating_cmp(x, y): + L.append(3) + L.pop() + return (x > y) - (x < y) + L = [1,2] + self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp)) + def mutating_cmp(x, y): + L.append(3) + del L[:] + return (x > y) - (x < y) + self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp)) + memorywaster = [memorywaster] + +#============================================================================== + +class TestDecorateSortUndecorate(unittest.TestCase): + + def test_decorated(self): + data = 'The quick Brown fox Jumped over The lazy Dog'.split() + copy = data[:] + random.shuffle(data) + data.sort(key=str.lower) + def my_cmp(x, y): + xlower, ylower = x.lower(), y.lower() + return (xlower > ylower) - (xlower < ylower) + copy.sort(key=cmp_to_key(my_cmp)) + + def test_baddecorator(self): + data = 'The quick Brown fox Jumped over The lazy Dog'.split() + self.assertRaises(TypeError, data.sort, key=lambda x,y: 0) + + def test_stability(self): + data = [(random.randrange(100), i) for i in range(200)] + copy = data[:] + data.sort(key=lambda t: t[0]) # sort on the random first field + copy.sort() # sort using both fields + self.assertEqual(data, copy) # should get the same result + + def test_key_with_exception(self): + # Verify that the wrapper has been removed + data = list(range(-2, 2)) + dup = data[:] + self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x) + self.assertEqual(data, dup) + + def test_key_with_mutation(self): + data = list(range(10)) + def k(x): + del data[:] + data[:] = range(20) + return x + self.assertRaises(ValueError, data.sort, key=k) + + @unittest.skip("TODO: RUSTPYTHON; destructors") + def test_key_with_mutating_del(self): + data = list(range(10)) + class SortKiller(object): + def __init__(self, x): + pass + def __del__(self): + del data[:] + data[:] = range(20) + def __lt__(self, other): + return id(self) < id(other) + self.assertRaises(ValueError, data.sort, key=SortKiller) + + @unittest.skip("TODO: RUSTPYTHON; destructors") + def test_key_with_mutating_del_and_exception(self): + data = list(range(10)) + ## dup = data[:] + class SortKiller(object): + def __init__(self, x): + if x > 2: + raise RuntimeError + def __del__(self): + del data[:] + data[:] = list(range(20)) + self.assertRaises(RuntimeError, data.sort, key=SortKiller) + ## major honking subtlety: we *can't* do: + ## + ## self.assertEqual(data, dup) + ## + ## because there is a reference to a SortKiller in the + ## traceback and by the time it dies we're outside the call to + ## .sort() and so the list protection gimmicks are out of + ## date (this cost some brain cells to figure out...). + + def test_reverse(self): + data = list(range(100)) + random.shuffle(data) + data.sort(reverse=True) + self.assertEqual(data, list(range(99,-1,-1))) + + def test_reverse_stability(self): + data = [(random.randrange(100), i) for i in range(200)] + copy1 = data[:] + copy2 = data[:] + def my_cmp(x, y): + x0, y0 = x[0], y[0] + return (x0 > y0) - (x0 < y0) + def my_cmp_reversed(x, y): + x0, y0 = x[0], y[0] + return (y0 > x0) - (y0 < x0) + data.sort(key=cmp_to_key(my_cmp), reverse=True) + copy1.sort(key=cmp_to_key(my_cmp_reversed)) + self.assertEqual(data, copy1) + copy2.sort(key=lambda x: x[0], reverse=True) + self.assertEqual(data, copy2) + +#============================================================================== +def check_against_PyObject_RichCompareBool(self, L): + ## The idea here is to exploit the fact that unsafe_tuple_compare uses + ## PyObject_RichCompareBool for the second elements of tuples. So we have, + ## for (most) L, sorted(L) == [y[1] for y in sorted([(0,x) for x in L])] + ## This will work as long as __eq__ => not __lt__ for all the objects in L, + ## which holds for all the types used below. + ## + ## Testing this way ensures that the optimized implementation remains consistent + ## with the naive implementation, even if changes are made to any of the + ## richcompares. + ## + ## This function tests sorting for three lists (it randomly shuffles each one): + ## 1. L + ## 2. [(x,) for x in L] + ## 3. [((x,),) for x in L] + + random.seed(0) + random.shuffle(L) + L_1 = L[:] + L_2 = [(x,) for x in L] + L_3 = [((x,),) for x in L] + for L in [L_1, L_2, L_3]: + optimized = sorted(L) + reference = [y[1] for y in sorted([(0,x) for x in L])] + for (opt, ref) in zip(optimized, reference): + self.assertIs(opt, ref) + #note: not assertEqual! We want to ensure *identical* behavior. + +class TestOptimizedCompares(unittest.TestCase): + def test_safe_object_compare(self): + heterogeneous_lists = [[0, 'foo'], + [0.0, 'foo'], + [('foo',), 'foo']] + for L in heterogeneous_lists: + self.assertRaises(TypeError, L.sort) + self.assertRaises(TypeError, [(x,) for x in L].sort) + self.assertRaises(TypeError, [((x,),) for x in L].sort) + + float_int_lists = [[1,1.1], + [1<<70,1.1], + [1.1,1], + [1.1,1<<70]] + for L in float_int_lists: + check_against_PyObject_RichCompareBool(self, L) + + # XXX RUSTPYTHON: added by us but it seems like an implementation detail + @support.cpython_only + def test_unsafe_object_compare(self): + + # This test is by ppperry. It ensures that unsafe_object_compare is + # verifying ms->key_richcompare == tp->richcompare before comparing. + + class WackyComparator(int): + def __lt__(self, other): + elem.__class__ = WackyList2 + return int.__lt__(self, other) + + class WackyList1(list): + pass + + class WackyList2(list): + def __lt__(self, other): + raise ValueError + + L = [WackyList1([WackyComparator(i), i]) for i in range(10)] + elem = L[-1] + with self.assertRaises(ValueError): + L.sort() + + L = [WackyList1([WackyComparator(i), i]) for i in range(10)] + elem = L[-1] + with self.assertRaises(ValueError): + [(x,) for x in L].sort() + + # The following test is also by ppperry. It ensures that + # unsafe_object_compare handles Py_NotImplemented appropriately. + class PointlessComparator: + def __lt__(self, other): + return NotImplemented + L = [PointlessComparator(), PointlessComparator()] + self.assertRaises(TypeError, L.sort) + self.assertRaises(TypeError, [(x,) for x in L].sort) + + # The following tests go through various types that would trigger + # ms->key_compare = unsafe_object_compare + lists = [list(range(100)) + [(1<<70)], + [str(x) for x in range(100)] + ['\uffff'], + [bytes(x) for x in range(100)], + [cmp_to_key(lambda x,y: x (x,) < (x,) + # + # Note that we don't have to put anything in tuples here, because + # the check function does a tuple test automatically. + + check_against_PyObject_RichCompareBool(self, [float('nan')]*100) + check_against_PyObject_RichCompareBool(self, [float('nan') for + _ in range(100)]) + + def test_not_all_tuples(self): + self.assertRaises(TypeError, [(1.0, 1.0), (False, "A"), 6].sort) + self.assertRaises(TypeError, [('a', 1), (1, 'a')].sort) + self.assertRaises(TypeError, [(1, 'a'), ('a', 1)].sort) +#============================================================================== + +if __name__ == "__main__": + unittest.main() diff --git a/vm/Cargo.toml b/vm/Cargo.toml index 6f2d27f2e..2d43fc09d 100644 --- a/vm/Cargo.toml +++ b/vm/Cargo.toml @@ -76,6 +76,7 @@ crossbeam-utils = "0.7" parking_lot = "0.11" thread_local = "1.0" cfg-if = "0.1.10" +timsort = "0.1" ## unicode stuff unicode_names2 = "0.4" diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index d5076e522..699210b3f 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -24,6 +24,7 @@ mod decl { use crate::obj::objfunction::PyFunctionRef; use crate::obj::objint::{self, PyIntRef}; use crate::obj::objiter; + use crate::obj::objlist::{PyList, SortOptions}; use crate::obj::objsequence; use crate::obj::objstr::{PyString, PyStringRef}; use crate::obj::objtype::{self, PyClassRef}; @@ -162,14 +163,14 @@ mod decl { } #[pyfunction] - fn dir(obj: OptionalArg, vm: &VirtualMachine) -> PyResult { + fn dir(obj: OptionalArg, vm: &VirtualMachine) -> PyResult { let seq = match obj { OptionalArg::Present(obj) => vm.call_method(&obj, "__dir__", vec![])?, OptionalArg::Missing => { vm.call_method(&vm.get_locals().into_object(), "keys", vec![])? } }; - let sorted = sorted(vm, PyFuncArgs::new(vec![seq], vec![]))?; + let sorted = sorted(seq, Default::default(), vm)?; Ok(sorted) } @@ -452,41 +453,32 @@ mod decl { } }; - if candidates.is_empty() { - return default - .ok_or_else(|| vm.new_value_error("max() arg is an empty sequence".to_owned())); - } - - // Start with first assumption: let mut candidates_iter = candidates.into_iter(); - let mut x = candidates_iter.next().unwrap(); - // TODO: this key function looks pretty duplicate. Maybe we can create - // a local function? - let mut x_key = if let Some(ref f) = &key_func { - if vm.is_none(f) { - x.clone() - } else { - vm.invoke(f, vec![x.clone()])? + let mut x = match candidates_iter.next() { + Some(x) => x, + None => { + return default + .ok_or_else(|| vm.new_value_error("max() arg is an empty sequence".to_owned())) } - } else { - x.clone() }; - for y in candidates_iter { - let y_key = if let Some(ref f) = &key_func { - if vm.is_none(f) { - y.clone() - } else { - vm.invoke(f, vec![y.clone()])? + let key_func = key_func.filter(|f| !vm.is_none(f)); + if let Some(ref key_func) = key_func { + let mut x_key = vm.invoke(key_func, x.clone())?; + for y in candidates_iter { + let y_key = vm.invoke(key_func, y.clone())?; + let y_gt_x = objbool::boolval(vm, vm._gt(y_key.clone(), x_key.clone())?)?; + if y_gt_x { + x = y; + x_key = y_key; + } + } + } else { + for y in candidates_iter { + let y_gt_x = objbool::boolval(vm, vm._gt(y.clone(), x.clone())?)?; + if y_gt_x { + x = y; } - } else { - y.clone() - }; - let order = vm._gt(x_key.clone(), y_key.clone())?; - - if !objbool::get_value(&order) { - x = y.clone(); - x_key = y_key; } } @@ -753,14 +745,10 @@ mod decl { // builtin_slice #[pyfunction] - fn sorted(vm: &VirtualMachine, mut args: PyFuncArgs) -> PyResult { - let iterable = args - .take_positional() - .ok_or_else(|| vm.new_type_error("sorted expected 1 arguments, got 0".to_string()))?; + fn sorted(iterable: PyObjectRef, opts: SortOptions, vm: &VirtualMachine) -> PyResult { let items = vm.extract_elements(&iterable)?; - let lst = vm.ctx.new_list(items); - - vm.call_method(&lst, "sort", args)?; + let lst = PyList::from(items); + lst.sort(opts, vm)?; Ok(lst) } diff --git a/vm/src/obj/objlist.rs b/vm/src/obj/objlist.rs index 8bcdfc596..fad440d0c 100644 --- a/vm/src/obj/objlist.rs +++ b/vm/src/obj/objlist.rs @@ -92,8 +92,8 @@ impl PyList { } } -#[derive(FromArgs)] -struct SortOptions { +#[derive(FromArgs, Default)] +pub(crate) struct SortOptions { #[pyarg(keyword_only, default = "None")] key: Option, #[pyarg(keyword_only, default = "false")] @@ -729,13 +729,14 @@ impl PyList { } #[pymethod] - fn sort(&self, options: SortOptions, vm: &VirtualMachine) -> PyResult<()> { + pub(crate) fn sort(&self, options: SortOptions, vm: &VirtualMachine) -> PyResult<()> { // replace list contents with [] for duration of sort. // this prevents keyfunc from messing with the list and makes it easy to // check if it tries to append elements to it. let mut elements = std::mem::take(self.borrow_value_mut().deref_mut()); - do_sort(vm, &mut elements, options.key, options.reverse)?; + let res = do_sort(vm, &mut elements, options.key, options.reverse); std::mem::swap(self.borrow_value_mut().deref_mut(), &mut elements); + res?; if !elements.is_empty() { return Err(vm.new_value_error("list modified during sort".to_owned())); @@ -760,66 +761,29 @@ impl PyList { } } -fn quicksort( - vm: &VirtualMachine, - keys: &mut [PyObjectRef], - values: &mut [PyObjectRef], -) -> PyResult<()> { - let len = values.len(); - if len >= 2 { - let pivot = partition(vm, keys, values)?; - quicksort(vm, &mut keys[0..pivot], &mut values[0..pivot])?; - quicksort(vm, &mut keys[pivot + 1..len], &mut values[pivot + 1..len])?; - } - Ok(()) -} - -fn partition( - vm: &VirtualMachine, - keys: &mut [PyObjectRef], - values: &mut [PyObjectRef], -) -> PyResult { - let len = values.len(); - let pivot = len / 2; - - values.swap(pivot, len - 1); - keys.swap(pivot, len - 1); - - let mut store_idx = 0; - for i in 0..len - 1 { - let result = vm._lt(keys[i].clone(), keys[len - 1].clone())?; - let boolval = objbool::boolval(vm, result)?; - if boolval { - values.swap(i, store_idx); - keys.swap(i, store_idx); - store_idx += 1; - } - } - - values.swap(store_idx, len - 1); - keys.swap(store_idx, len - 1); - Ok(store_idx) -} - fn do_sort( vm: &VirtualMachine, values: &mut Vec, key_func: Option, reverse: bool, ) -> PyResult<()> { - // build a list of keys. If no keyfunc is provided, it's a copy of the list. - let mut keys: Vec = vec![]; - for x in values.iter() { - keys.push(match &key_func { - None => x.clone(), - Some(ref func) => vm.invoke(func, vec![x.clone()])?, - }); - } + let cmp = if reverse { + VirtualMachine::_lt + } else { + VirtualMachine::_gt + }; + let cmp = + |a: &PyObjectRef, b: &PyObjectRef| objbool::boolval(vm, cmp(vm, a.clone(), b.clone())?); - quicksort(vm, &mut keys, values)?; - - if reverse { - values.reverse(); + if let Some(ref key_func) = key_func { + let mut items = values + .iter() + .map(|x| Ok((x.clone(), vm.invoke(key_func, vec![x.clone()])?))) + .collect::, _>>()?; + timsort::try_sort_by_gt(&mut items, |a, b| cmp(&a.1, &b.1))?; + *values = items.into_iter().map(|(val, _)| val).collect(); + } else { + timsort::try_sort_by_gt(values, cmp)?; } Ok(()) diff --git a/vm/src/stdlib/random.rs b/vm/src/stdlib/random.rs index aab9e6909..77193ec06 100644 --- a/vm/src/stdlib/random.rs +++ b/vm/src/stdlib/random.rs @@ -91,7 +91,8 @@ mod _random { if cfg!(target_endian = "big") { key.reverse(); } - PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(&key))) + let key = if key.is_empty() { &[0] } else { key.as_slice() }; + PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(key))) } };