mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Fix combinations implementation
This commit is contained in:
@@ -301,7 +301,6 @@ assert list(t[0]) == [1,2,3]
|
||||
assert list(t[0]) == []
|
||||
|
||||
# itertools.product
|
||||
|
||||
it = itertools.product([1, 2], [3, 4])
|
||||
assert (1, 3) == next(it)
|
||||
assert (1, 4) == next(it)
|
||||
@@ -324,10 +323,12 @@ with assert_raises(TypeError):
|
||||
itertools.product([1, 2], repeat=None)
|
||||
|
||||
# itertools.combinations
|
||||
|
||||
it = itertools.combinations([1, 2, 3, 4], 2)
|
||||
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
|
||||
|
||||
it = itertools.combinations([1, 2, 3], 0)
|
||||
assert list(it) == [()]
|
||||
|
||||
it = itertools.combinations([1, 2, 3], 1)
|
||||
assert list(it) == [(1,), (2,), (3,)]
|
||||
|
||||
@@ -345,6 +346,9 @@ with assert_raises(StopIteration):
|
||||
with assert_raises(ValueError):
|
||||
itertools.combinations([1, 2, 3, 4], -2)
|
||||
|
||||
with assert_raises(TypeError):
|
||||
itertools.combinations([1, 2, 3, 4], None)
|
||||
|
||||
# itertools.zip_longest tests
|
||||
zl = itertools.zip_longest
|
||||
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7])) \
|
||||
|
||||
@@ -61,7 +61,7 @@ Basically reference counting, but then done by rust.
|
||||
/// to the python object by 1.
|
||||
pub type PyObjectRef = Rc<PyObject<dyn PyObjectPayload>>;
|
||||
|
||||
/// Use this type for function which return a python object or and exception.
|
||||
/// Use this type for functions which return a python object or an exception.
|
||||
/// Both the python object and the python exception are `PyObjectRef` types
|
||||
/// since exceptions are also python objects.
|
||||
pub type PyResult<T = PyObjectRef> = Result<T, PyObjectRef>; // A valid value, or an exception
|
||||
|
||||
@@ -314,7 +314,7 @@ impl PyItertoolsTakewhile {
|
||||
return Err(new_stop_iteration(vm));
|
||||
}
|
||||
|
||||
// might be StopIteration or anything else, which is propaged upwwards
|
||||
// might be StopIteration or anything else, which is propagated upwards
|
||||
let obj = call_next(vm, &self.iterable)?;
|
||||
let predicate = &self.predicate;
|
||||
|
||||
@@ -908,45 +908,40 @@ impl PyItertoolsCombinations {
|
||||
let n = self.pool.len();
|
||||
let r = self.r.get();
|
||||
|
||||
if r == 0 {
|
||||
self.exhausted.set(true);
|
||||
return Ok(vm.ctx.new_tuple(vec![]));
|
||||
}
|
||||
|
||||
let res = PyTuple::from(
|
||||
self.pool
|
||||
self.indices
|
||||
.borrow()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
|
||||
.map(|(_, num)| num.clone())
|
||||
.map(|&i| self.pool[i].clone())
|
||||
.collect::<Vec<PyObjectRef>>(),
|
||||
);
|
||||
|
||||
let mut indices = self.indices.borrow_mut();
|
||||
let mut sentinel = false;
|
||||
|
||||
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
|
||||
let mut idx = r - 1;
|
||||
loop {
|
||||
if indices[idx] != idx + n - r {
|
||||
sentinel = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if idx != 0 {
|
||||
idx -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
let mut idx = r as isize - 1;
|
||||
while idx >= 0 && indices[idx as usize] == idx as usize + n - r {
|
||||
idx -= 1;
|
||||
}
|
||||
|
||||
// If no suitable index is found, then the indices are all at
|
||||
// their maximum value and we're done.
|
||||
if !sentinel {
|
||||
if idx < 0 {
|
||||
self.exhausted.set(true);
|
||||
}
|
||||
|
||||
// Increment the current index which we know is not at its
|
||||
// maximum. Then move back to the right setting each index
|
||||
// to its lowest possible value (one higher than the index
|
||||
// to its left -- this maintains the sort order invariant).
|
||||
indices[idx] += 1;
|
||||
for j in idx + 1..r {
|
||||
indices[j] = indices[j - 1] + 1;
|
||||
} else {
|
||||
// Increment the current index which we know is not at its
|
||||
// maximum. Then move back to the right setting each index
|
||||
// to its lowest possible value (one higher than the index
|
||||
// to its left -- this maintains the sort order invariant).
|
||||
indices[idx as usize] += 1;
|
||||
for j in idx as usize + 1..r {
|
||||
indices[j] = indices[j - 1] + 1;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(res.into_ref(vm).into_object())
|
||||
|
||||
Reference in New Issue
Block a user