mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Merge pull request #2152 from RustPython/coolreader18/stable-sort
This commit is contained in:
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -1640,6 +1640,7 @@ dependencies = [
|
||||
"socket2",
|
||||
"statrs",
|
||||
"thread_local",
|
||||
"timsort",
|
||||
"uname",
|
||||
"unic-char-property",
|
||||
"unic-normal",
|
||||
@@ -1984,6 +1985,12 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "timsort"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a035368bf2997adda2f78c1e2683e2b8e8f584e349e8cb7f8adf294b0a4ad70"
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "0.3.4"
|
||||
|
||||
@@ -873,7 +873,6 @@ class BuiltinTest(unittest.TestCase):
|
||||
m2 = map(map_char, "Is this the real life?")
|
||||
self.check_iter_pickle(m1, list(m2), proto)
|
||||
|
||||
@unittest.skip("TODO: RUSTPYTHON")
|
||||
def test_max(self):
|
||||
self.assertEqual(max('123123'), '3')
|
||||
self.assertEqual(max(1, 2, 3), 3)
|
||||
|
||||
390
Lib/test/test_sort.py
Normal file
390
Lib/test/test_sort.py
Normal file
@@ -0,0 +1,390 @@
|
||||
from test import support
|
||||
import random
|
||||
import unittest
|
||||
from functools import cmp_to_key
|
||||
|
||||
verbose = support.verbose
|
||||
nerrors = 0
|
||||
|
||||
|
||||
def check(tag, expected, raw, compare=None):
|
||||
global nerrors
|
||||
|
||||
if verbose:
|
||||
print(" checking", tag)
|
||||
|
||||
orig = raw[:] # save input in case of error
|
||||
if compare:
|
||||
raw.sort(key=cmp_to_key(compare))
|
||||
else:
|
||||
raw.sort()
|
||||
|
||||
if len(expected) != len(raw):
|
||||
print("error in", tag)
|
||||
print("length mismatch;", len(expected), len(raw))
|
||||
print(expected)
|
||||
print(orig)
|
||||
print(raw)
|
||||
nerrors += 1
|
||||
return
|
||||
|
||||
for i, good in enumerate(expected):
|
||||
maybe = raw[i]
|
||||
if good is not maybe:
|
||||
print("error in", tag)
|
||||
print("out of order at index", i, good, maybe)
|
||||
print(expected)
|
||||
print(orig)
|
||||
print(raw)
|
||||
nerrors += 1
|
||||
return
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
def testStressfully(self):
|
||||
# Try a variety of sizes at and around powers of 2, and at powers of 10.
|
||||
sizes = [0]
|
||||
for power in range(1, 10):
|
||||
n = 2 ** power
|
||||
sizes.extend(range(n-1, n+2))
|
||||
sizes.extend([10, 100, 1000])
|
||||
|
||||
class Complains(object):
|
||||
maybe_complain = True
|
||||
|
||||
def __init__(self, i):
|
||||
self.i = i
|
||||
|
||||
def __lt__(self, other):
|
||||
if Complains.maybe_complain and random.random() < 0.001:
|
||||
if verbose:
|
||||
print(" complaining at", self, other)
|
||||
raise RuntimeError
|
||||
return self.i < other.i
|
||||
|
||||
def __repr__(self):
|
||||
return "Complains(%d)" % self.i
|
||||
|
||||
class Stable(object):
|
||||
def __init__(self, key, i):
|
||||
self.key = key
|
||||
self.index = i
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.key < other.key
|
||||
|
||||
def __repr__(self):
|
||||
return "Stable(%d, %d)" % (self.key, self.index)
|
||||
|
||||
for n in sizes:
|
||||
x = list(range(n))
|
||||
if verbose:
|
||||
print("Testing size", n)
|
||||
|
||||
s = x[:]
|
||||
check("identity", x, s)
|
||||
|
||||
s = x[:]
|
||||
s.reverse()
|
||||
check("reversed", x, s)
|
||||
|
||||
s = x[:]
|
||||
random.shuffle(s)
|
||||
check("random permutation", x, s)
|
||||
|
||||
y = x[:]
|
||||
y.reverse()
|
||||
s = x[:]
|
||||
check("reversed via function", y, s, lambda a, b: (b>a)-(b<a))
|
||||
|
||||
if verbose:
|
||||
print(" Checking against an insane comparison function.")
|
||||
print(" If the implementation isn't careful, this may segfault.")
|
||||
s = x[:]
|
||||
s.sort(key=cmp_to_key(lambda a, b: int(random.random() * 3) - 1))
|
||||
check("an insane function left some permutation", x, s)
|
||||
|
||||
if len(x) >= 2:
|
||||
def bad_key(x):
|
||||
raise RuntimeError
|
||||
s = x[:]
|
||||
self.assertRaises(RuntimeError, s.sort, key=bad_key)
|
||||
|
||||
x = [Complains(i) for i in x]
|
||||
s = x[:]
|
||||
random.shuffle(s)
|
||||
Complains.maybe_complain = True
|
||||
it_complained = False
|
||||
try:
|
||||
s.sort()
|
||||
except RuntimeError:
|
||||
it_complained = True
|
||||
if it_complained:
|
||||
Complains.maybe_complain = False
|
||||
check("exception during sort left some permutation", x, s)
|
||||
|
||||
s = [Stable(random.randrange(10), i) for i in range(n)]
|
||||
augmented = [(e, e.index) for e in s]
|
||||
augmented.sort() # forced stable because ties broken by index
|
||||
x = [e for e, i in augmented] # a stable sort of s
|
||||
check("stability", x, s)
|
||||
|
||||
#==============================================================================
|
||||
|
||||
class TestBugs(unittest.TestCase):
|
||||
|
||||
@unittest.skip("TODO: RUSTPYTHON; figure out how to detect sort mutation that doesn't change list length")
|
||||
def test_bug453523(self):
|
||||
# bug 453523 -- list.sort() crasher.
|
||||
# If this fails, the most likely outcome is a core dump.
|
||||
# Mutations during a list sort should raise a ValueError.
|
||||
|
||||
class C:
|
||||
def __lt__(self, other):
|
||||
if L and random.random() < 0.75:
|
||||
L.pop()
|
||||
else:
|
||||
L.append(3)
|
||||
return random.random() < 0.5
|
||||
|
||||
L = [C() for i in range(50)]
|
||||
self.assertRaises(ValueError, L.sort)
|
||||
|
||||
@unittest.skip("TODO: RUSTPYTHON; figure out how to detect sort mutation that doesn't change list length")
|
||||
def test_undetected_mutation(self):
|
||||
# Python 2.4a1 did not always detect mutation
|
||||
memorywaster = []
|
||||
for i in range(20):
|
||||
def mutating_cmp(x, y):
|
||||
L.append(3)
|
||||
L.pop()
|
||||
return (x > y) - (x < y)
|
||||
L = [1,2]
|
||||
self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
|
||||
def mutating_cmp(x, y):
|
||||
L.append(3)
|
||||
del L[:]
|
||||
return (x > y) - (x < y)
|
||||
self.assertRaises(ValueError, L.sort, key=cmp_to_key(mutating_cmp))
|
||||
memorywaster = [memorywaster]
|
||||
|
||||
#==============================================================================
|
||||
|
||||
class TestDecorateSortUndecorate(unittest.TestCase):
|
||||
|
||||
def test_decorated(self):
|
||||
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
|
||||
copy = data[:]
|
||||
random.shuffle(data)
|
||||
data.sort(key=str.lower)
|
||||
def my_cmp(x, y):
|
||||
xlower, ylower = x.lower(), y.lower()
|
||||
return (xlower > ylower) - (xlower < ylower)
|
||||
copy.sort(key=cmp_to_key(my_cmp))
|
||||
|
||||
def test_baddecorator(self):
|
||||
data = 'The quick Brown fox Jumped over The lazy Dog'.split()
|
||||
self.assertRaises(TypeError, data.sort, key=lambda x,y: 0)
|
||||
|
||||
def test_stability(self):
|
||||
data = [(random.randrange(100), i) for i in range(200)]
|
||||
copy = data[:]
|
||||
data.sort(key=lambda t: t[0]) # sort on the random first field
|
||||
copy.sort() # sort using both fields
|
||||
self.assertEqual(data, copy) # should get the same result
|
||||
|
||||
def test_key_with_exception(self):
|
||||
# Verify that the wrapper has been removed
|
||||
data = list(range(-2, 2))
|
||||
dup = data[:]
|
||||
self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x)
|
||||
self.assertEqual(data, dup)
|
||||
|
||||
def test_key_with_mutation(self):
|
||||
data = list(range(10))
|
||||
def k(x):
|
||||
del data[:]
|
||||
data[:] = range(20)
|
||||
return x
|
||||
self.assertRaises(ValueError, data.sort, key=k)
|
||||
|
||||
@unittest.skip("TODO: RUSTPYTHON; destructors")
|
||||
def test_key_with_mutating_del(self):
|
||||
data = list(range(10))
|
||||
class SortKiller(object):
|
||||
def __init__(self, x):
|
||||
pass
|
||||
def __del__(self):
|
||||
del data[:]
|
||||
data[:] = range(20)
|
||||
def __lt__(self, other):
|
||||
return id(self) < id(other)
|
||||
self.assertRaises(ValueError, data.sort, key=SortKiller)
|
||||
|
||||
@unittest.skip("TODO: RUSTPYTHON; destructors")
|
||||
def test_key_with_mutating_del_and_exception(self):
|
||||
data = list(range(10))
|
||||
## dup = data[:]
|
||||
class SortKiller(object):
|
||||
def __init__(self, x):
|
||||
if x > 2:
|
||||
raise RuntimeError
|
||||
def __del__(self):
|
||||
del data[:]
|
||||
data[:] = list(range(20))
|
||||
self.assertRaises(RuntimeError, data.sort, key=SortKiller)
|
||||
## major honking subtlety: we *can't* do:
|
||||
##
|
||||
## self.assertEqual(data, dup)
|
||||
##
|
||||
## because there is a reference to a SortKiller in the
|
||||
## traceback and by the time it dies we're outside the call to
|
||||
## .sort() and so the list protection gimmicks are out of
|
||||
## date (this cost some brain cells to figure out...).
|
||||
|
||||
def test_reverse(self):
|
||||
data = list(range(100))
|
||||
random.shuffle(data)
|
||||
data.sort(reverse=True)
|
||||
self.assertEqual(data, list(range(99,-1,-1)))
|
||||
|
||||
def test_reverse_stability(self):
|
||||
data = [(random.randrange(100), i) for i in range(200)]
|
||||
copy1 = data[:]
|
||||
copy2 = data[:]
|
||||
def my_cmp(x, y):
|
||||
x0, y0 = x[0], y[0]
|
||||
return (x0 > y0) - (x0 < y0)
|
||||
def my_cmp_reversed(x, y):
|
||||
x0, y0 = x[0], y[0]
|
||||
return (y0 > x0) - (y0 < x0)
|
||||
data.sort(key=cmp_to_key(my_cmp), reverse=True)
|
||||
copy1.sort(key=cmp_to_key(my_cmp_reversed))
|
||||
self.assertEqual(data, copy1)
|
||||
copy2.sort(key=lambda x: x[0], reverse=True)
|
||||
self.assertEqual(data, copy2)
|
||||
|
||||
#==============================================================================
|
||||
def check_against_PyObject_RichCompareBool(self, L):
|
||||
## The idea here is to exploit the fact that unsafe_tuple_compare uses
|
||||
## PyObject_RichCompareBool for the second elements of tuples. So we have,
|
||||
## for (most) L, sorted(L) == [y[1] for y in sorted([(0,x) for x in L])]
|
||||
## This will work as long as __eq__ => not __lt__ for all the objects in L,
|
||||
## which holds for all the types used below.
|
||||
##
|
||||
## Testing this way ensures that the optimized implementation remains consistent
|
||||
## with the naive implementation, even if changes are made to any of the
|
||||
## richcompares.
|
||||
##
|
||||
## This function tests sorting for three lists (it randomly shuffles each one):
|
||||
## 1. L
|
||||
## 2. [(x,) for x in L]
|
||||
## 3. [((x,),) for x in L]
|
||||
|
||||
random.seed(0)
|
||||
random.shuffle(L)
|
||||
L_1 = L[:]
|
||||
L_2 = [(x,) for x in L]
|
||||
L_3 = [((x,),) for x in L]
|
||||
for L in [L_1, L_2, L_3]:
|
||||
optimized = sorted(L)
|
||||
reference = [y[1] for y in sorted([(0,x) for x in L])]
|
||||
for (opt, ref) in zip(optimized, reference):
|
||||
self.assertIs(opt, ref)
|
||||
#note: not assertEqual! We want to ensure *identical* behavior.
|
||||
|
||||
class TestOptimizedCompares(unittest.TestCase):
|
||||
def test_safe_object_compare(self):
|
||||
heterogeneous_lists = [[0, 'foo'],
|
||||
[0.0, 'foo'],
|
||||
[('foo',), 'foo']]
|
||||
for L in heterogeneous_lists:
|
||||
self.assertRaises(TypeError, L.sort)
|
||||
self.assertRaises(TypeError, [(x,) for x in L].sort)
|
||||
self.assertRaises(TypeError, [((x,),) for x in L].sort)
|
||||
|
||||
float_int_lists = [[1,1.1],
|
||||
[1<<70,1.1],
|
||||
[1.1,1],
|
||||
[1.1,1<<70]]
|
||||
for L in float_int_lists:
|
||||
check_against_PyObject_RichCompareBool(self, L)
|
||||
|
||||
# XXX RUSTPYTHON: added by us but it seems like an implementation detail
|
||||
@support.cpython_only
|
||||
def test_unsafe_object_compare(self):
|
||||
|
||||
# This test is by ppperry. It ensures that unsafe_object_compare is
|
||||
# verifying ms->key_richcompare == tp->richcompare before comparing.
|
||||
|
||||
class WackyComparator(int):
|
||||
def __lt__(self, other):
|
||||
elem.__class__ = WackyList2
|
||||
return int.__lt__(self, other)
|
||||
|
||||
class WackyList1(list):
|
||||
pass
|
||||
|
||||
class WackyList2(list):
|
||||
def __lt__(self, other):
|
||||
raise ValueError
|
||||
|
||||
L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
|
||||
elem = L[-1]
|
||||
with self.assertRaises(ValueError):
|
||||
L.sort()
|
||||
|
||||
L = [WackyList1([WackyComparator(i), i]) for i in range(10)]
|
||||
elem = L[-1]
|
||||
with self.assertRaises(ValueError):
|
||||
[(x,) for x in L].sort()
|
||||
|
||||
# The following test is also by ppperry. It ensures that
|
||||
# unsafe_object_compare handles Py_NotImplemented appropriately.
|
||||
class PointlessComparator:
|
||||
def __lt__(self, other):
|
||||
return NotImplemented
|
||||
L = [PointlessComparator(), PointlessComparator()]
|
||||
self.assertRaises(TypeError, L.sort)
|
||||
self.assertRaises(TypeError, [(x,) for x in L].sort)
|
||||
|
||||
# The following tests go through various types that would trigger
|
||||
# ms->key_compare = unsafe_object_compare
|
||||
lists = [list(range(100)) + [(1<<70)],
|
||||
[str(x) for x in range(100)] + ['\uffff'],
|
||||
[bytes(x) for x in range(100)],
|
||||
[cmp_to_key(lambda x,y: x<y)(x) for x in range(100)]]
|
||||
for L in lists:
|
||||
check_against_PyObject_RichCompareBool(self, L)
|
||||
|
||||
def test_unsafe_latin_compare(self):
|
||||
check_against_PyObject_RichCompareBool(self, [str(x) for
|
||||
x in range(100)])
|
||||
|
||||
def test_unsafe_long_compare(self):
|
||||
check_against_PyObject_RichCompareBool(self, [x for
|
||||
x in range(100)])
|
||||
|
||||
def test_unsafe_float_compare(self):
|
||||
check_against_PyObject_RichCompareBool(self, [float(x) for
|
||||
x in range(100)])
|
||||
|
||||
def test_unsafe_tuple_compare(self):
|
||||
# This test was suggested by Tim Peters. It verifies that the tuple
|
||||
# comparison respects the current tuple compare semantics, which do not
|
||||
# guarantee that x < x <=> (x,) < (x,)
|
||||
#
|
||||
# Note that we don't have to put anything in tuples here, because
|
||||
# the check function does a tuple test automatically.
|
||||
|
||||
check_against_PyObject_RichCompareBool(self, [float('nan')]*100)
|
||||
check_against_PyObject_RichCompareBool(self, [float('nan') for
|
||||
_ in range(100)])
|
||||
|
||||
def test_not_all_tuples(self):
|
||||
self.assertRaises(TypeError, [(1.0, 1.0), (False, "A"), 6].sort)
|
||||
self.assertRaises(TypeError, [('a', 1), (1, 'a')].sort)
|
||||
self.assertRaises(TypeError, [(1, 'a'), ('a', 1)].sort)
|
||||
#==============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -76,6 +76,7 @@ crossbeam-utils = "0.7"
|
||||
parking_lot = "0.11"
|
||||
thread_local = "1.0"
|
||||
cfg-if = "0.1.10"
|
||||
timsort = "0.1"
|
||||
|
||||
## unicode stuff
|
||||
unicode_names2 = "0.4"
|
||||
|
||||
@@ -24,6 +24,7 @@ mod decl {
|
||||
use crate::obj::objfunction::PyFunctionRef;
|
||||
use crate::obj::objint::{self, PyIntRef};
|
||||
use crate::obj::objiter;
|
||||
use crate::obj::objlist::{PyList, SortOptions};
|
||||
use crate::obj::objsequence;
|
||||
use crate::obj::objstr::{PyString, PyStringRef};
|
||||
use crate::obj::objtype::{self, PyClassRef};
|
||||
@@ -162,14 +163,14 @@ mod decl {
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
fn dir(obj: OptionalArg<PyObjectRef>, vm: &VirtualMachine) -> PyResult {
|
||||
fn dir(obj: OptionalArg<PyObjectRef>, vm: &VirtualMachine) -> PyResult<PyList> {
|
||||
let seq = match obj {
|
||||
OptionalArg::Present(obj) => vm.call_method(&obj, "__dir__", vec![])?,
|
||||
OptionalArg::Missing => {
|
||||
vm.call_method(&vm.get_locals().into_object(), "keys", vec![])?
|
||||
}
|
||||
};
|
||||
let sorted = sorted(vm, PyFuncArgs::new(vec![seq], vec![]))?;
|
||||
let sorted = sorted(seq, Default::default(), vm)?;
|
||||
Ok(sorted)
|
||||
}
|
||||
|
||||
@@ -452,41 +453,32 @@ mod decl {
|
||||
}
|
||||
};
|
||||
|
||||
if candidates.is_empty() {
|
||||
return default
|
||||
.ok_or_else(|| vm.new_value_error("max() arg is an empty sequence".to_owned()));
|
||||
}
|
||||
|
||||
// Start with first assumption:
|
||||
let mut candidates_iter = candidates.into_iter();
|
||||
let mut x = candidates_iter.next().unwrap();
|
||||
// TODO: this key function looks pretty duplicate. Maybe we can create
|
||||
// a local function?
|
||||
let mut x_key = if let Some(ref f) = &key_func {
|
||||
if vm.is_none(f) {
|
||||
x.clone()
|
||||
} else {
|
||||
vm.invoke(f, vec![x.clone()])?
|
||||
let mut x = match candidates_iter.next() {
|
||||
Some(x) => x,
|
||||
None => {
|
||||
return default
|
||||
.ok_or_else(|| vm.new_value_error("max() arg is an empty sequence".to_owned()))
|
||||
}
|
||||
} else {
|
||||
x.clone()
|
||||
};
|
||||
|
||||
for y in candidates_iter {
|
||||
let y_key = if let Some(ref f) = &key_func {
|
||||
if vm.is_none(f) {
|
||||
y.clone()
|
||||
} else {
|
||||
vm.invoke(f, vec![y.clone()])?
|
||||
let key_func = key_func.filter(|f| !vm.is_none(f));
|
||||
if let Some(ref key_func) = key_func {
|
||||
let mut x_key = vm.invoke(key_func, x.clone())?;
|
||||
for y in candidates_iter {
|
||||
let y_key = vm.invoke(key_func, y.clone())?;
|
||||
let y_gt_x = objbool::boolval(vm, vm._gt(y_key.clone(), x_key.clone())?)?;
|
||||
if y_gt_x {
|
||||
x = y;
|
||||
x_key = y_key;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for y in candidates_iter {
|
||||
let y_gt_x = objbool::boolval(vm, vm._gt(y.clone(), x.clone())?)?;
|
||||
if y_gt_x {
|
||||
x = y;
|
||||
}
|
||||
} else {
|
||||
y.clone()
|
||||
};
|
||||
let order = vm._gt(x_key.clone(), y_key.clone())?;
|
||||
|
||||
if !objbool::get_value(&order) {
|
||||
x = y.clone();
|
||||
x_key = y_key;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -753,14 +745,10 @@ mod decl {
|
||||
// builtin_slice
|
||||
|
||||
#[pyfunction]
|
||||
fn sorted(vm: &VirtualMachine, mut args: PyFuncArgs) -> PyResult {
|
||||
let iterable = args
|
||||
.take_positional()
|
||||
.ok_or_else(|| vm.new_type_error("sorted expected 1 arguments, got 0".to_string()))?;
|
||||
fn sorted(iterable: PyObjectRef, opts: SortOptions, vm: &VirtualMachine) -> PyResult<PyList> {
|
||||
let items = vm.extract_elements(&iterable)?;
|
||||
let lst = vm.ctx.new_list(items);
|
||||
|
||||
vm.call_method(&lst, "sort", args)?;
|
||||
let lst = PyList::from(items);
|
||||
lst.sort(opts, vm)?;
|
||||
Ok(lst)
|
||||
}
|
||||
|
||||
|
||||
@@ -92,8 +92,8 @@ impl PyList {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
struct SortOptions {
|
||||
#[derive(FromArgs, Default)]
|
||||
pub(crate) struct SortOptions {
|
||||
#[pyarg(keyword_only, default = "None")]
|
||||
key: Option<PyObjectRef>,
|
||||
#[pyarg(keyword_only, default = "false")]
|
||||
@@ -729,13 +729,14 @@ impl PyList {
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn sort(&self, options: SortOptions, vm: &VirtualMachine) -> PyResult<()> {
|
||||
pub(crate) fn sort(&self, options: SortOptions, vm: &VirtualMachine) -> PyResult<()> {
|
||||
// replace list contents with [] for duration of sort.
|
||||
// this prevents keyfunc from messing with the list and makes it easy to
|
||||
// check if it tries to append elements to it.
|
||||
let mut elements = std::mem::take(self.borrow_value_mut().deref_mut());
|
||||
do_sort(vm, &mut elements, options.key, options.reverse)?;
|
||||
let res = do_sort(vm, &mut elements, options.key, options.reverse);
|
||||
std::mem::swap(self.borrow_value_mut().deref_mut(), &mut elements);
|
||||
res?;
|
||||
|
||||
if !elements.is_empty() {
|
||||
return Err(vm.new_value_error("list modified during sort".to_owned()));
|
||||
@@ -760,66 +761,29 @@ impl PyList {
|
||||
}
|
||||
}
|
||||
|
||||
fn quicksort(
|
||||
vm: &VirtualMachine,
|
||||
keys: &mut [PyObjectRef],
|
||||
values: &mut [PyObjectRef],
|
||||
) -> PyResult<()> {
|
||||
let len = values.len();
|
||||
if len >= 2 {
|
||||
let pivot = partition(vm, keys, values)?;
|
||||
quicksort(vm, &mut keys[0..pivot], &mut values[0..pivot])?;
|
||||
quicksort(vm, &mut keys[pivot + 1..len], &mut values[pivot + 1..len])?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn partition(
|
||||
vm: &VirtualMachine,
|
||||
keys: &mut [PyObjectRef],
|
||||
values: &mut [PyObjectRef],
|
||||
) -> PyResult<usize> {
|
||||
let len = values.len();
|
||||
let pivot = len / 2;
|
||||
|
||||
values.swap(pivot, len - 1);
|
||||
keys.swap(pivot, len - 1);
|
||||
|
||||
let mut store_idx = 0;
|
||||
for i in 0..len - 1 {
|
||||
let result = vm._lt(keys[i].clone(), keys[len - 1].clone())?;
|
||||
let boolval = objbool::boolval(vm, result)?;
|
||||
if boolval {
|
||||
values.swap(i, store_idx);
|
||||
keys.swap(i, store_idx);
|
||||
store_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
values.swap(store_idx, len - 1);
|
||||
keys.swap(store_idx, len - 1);
|
||||
Ok(store_idx)
|
||||
}
|
||||
|
||||
fn do_sort(
|
||||
vm: &VirtualMachine,
|
||||
values: &mut Vec<PyObjectRef>,
|
||||
key_func: Option<PyObjectRef>,
|
||||
reverse: bool,
|
||||
) -> PyResult<()> {
|
||||
// build a list of keys. If no keyfunc is provided, it's a copy of the list.
|
||||
let mut keys: Vec<PyObjectRef> = vec![];
|
||||
for x in values.iter() {
|
||||
keys.push(match &key_func {
|
||||
None => x.clone(),
|
||||
Some(ref func) => vm.invoke(func, vec![x.clone()])?,
|
||||
});
|
||||
}
|
||||
let cmp = if reverse {
|
||||
VirtualMachine::_lt
|
||||
} else {
|
||||
VirtualMachine::_gt
|
||||
};
|
||||
let cmp =
|
||||
|a: &PyObjectRef, b: &PyObjectRef| objbool::boolval(vm, cmp(vm, a.clone(), b.clone())?);
|
||||
|
||||
quicksort(vm, &mut keys, values)?;
|
||||
|
||||
if reverse {
|
||||
values.reverse();
|
||||
if let Some(ref key_func) = key_func {
|
||||
let mut items = values
|
||||
.iter()
|
||||
.map(|x| Ok((x.clone(), vm.invoke(key_func, vec![x.clone()])?)))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
timsort::try_sort_by_gt(&mut items, |a, b| cmp(&a.1, &b.1))?;
|
||||
*values = items.into_iter().map(|(val, _)| val).collect();
|
||||
} else {
|
||||
timsort::try_sort_by_gt(values, cmp)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -91,7 +91,8 @@ mod _random {
|
||||
if cfg!(target_endian = "big") {
|
||||
key.reverse();
|
||||
}
|
||||
PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(&key)))
|
||||
let key = if key.is_empty() { &[0] } else { key.as_slice() };
|
||||
PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(key)))
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user