Fix combinations implementation

This commit is contained in:
Daniel Alley
2019-12-17 11:12:24 -05:00
parent 39ad03c7fb
commit f32fe9dcb8
3 changed files with 30 additions and 31 deletions

View File

@@ -301,7 +301,6 @@ assert list(t[0]) == [1,2,3]
assert list(t[0]) == []
# itertools.product
it = itertools.product([1, 2], [3, 4])
assert (1, 3) == next(it)
assert (1, 4) == next(it)
@@ -324,10 +323,12 @@ with assert_raises(TypeError):
itertools.product([1, 2], repeat=None)
# itertools.combinations
it = itertools.combinations([1, 2, 3, 4], 2)
assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
it = itertools.combinations([1, 2, 3], 0)
assert list(it) == [()]
it = itertools.combinations([1, 2, 3], 1)
assert list(it) == [(1,), (2,), (3,)]
@@ -345,6 +346,9 @@ with assert_raises(StopIteration):
with assert_raises(ValueError):
itertools.combinations([1, 2, 3, 4], -2)
with assert_raises(TypeError):
itertools.combinations([1, 2, 3, 4], None)
# itertools.zip_longest tests
zl = itertools.zip_longest
assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7])) \

View File

@@ -61,7 +61,7 @@ Basically reference counting, but then done by rust.
/// to the python object by 1.
pub type PyObjectRef = Rc<PyObject<dyn PyObjectPayload>>;
/// Use this type for function which return a python object or and exception.
/// Use this type for functions which return a python object or an exception.
/// Both the python object and the python exception are `PyObjectRef` types
/// since exceptions are also python objects.
pub type PyResult<T = PyObjectRef> = Result<T, PyObjectRef>; // A valid value, or an exception

View File

@@ -314,7 +314,7 @@ impl PyItertoolsTakewhile {
return Err(new_stop_iteration(vm));
}
// might be StopIteration or anything else, which is propaged upwwards
// might be StopIteration or anything else, which is propagated upwards
let obj = call_next(vm, &self.iterable)?;
let predicate = &self.predicate;
@@ -908,45 +908,40 @@ impl PyItertoolsCombinations {
let n = self.pool.len();
let r = self.r.get();
if r == 0 {
self.exhausted.set(true);
return Ok(vm.ctx.new_tuple(vec![]));
}
let res = PyTuple::from(
self.pool
self.indices
.borrow()
.iter()
.enumerate()
.filter(|(idx, _)| self.indices.borrow().contains(&idx))
.map(|(_, num)| num.clone())
.map(|&i| self.pool[i].clone())
.collect::<Vec<PyObjectRef>>(),
);
let mut indices = self.indices.borrow_mut();
let mut sentinel = false;
// Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
let mut idx = r - 1;
loop {
if indices[idx] != idx + n - r {
sentinel = true;
break;
}
if idx != 0 {
idx -= 1;
} else {
break;
}
let mut idx = r as isize - 1;
while idx >= 0 && indices[idx as usize] == idx as usize + n - r {
idx -= 1;
}
// If no suitable index is found, then the indices are all at
// their maximum value and we're done.
if !sentinel {
if idx < 0 {
self.exhausted.set(true);
}
// Increment the current index which we know is not at its
// maximum. Then move back to the right setting each index
// to its lowest possible value (one higher than the index
// to its left -- this maintains the sort order invariant).
indices[idx] += 1;
for j in idx + 1..r {
indices[j] = indices[j - 1] + 1;
} else {
// Increment the current index which we know is not at its
// maximum. Then move back to the right setting each index
// to its lowest possible value (one higher than the index
// to its left -- this maintains the sort order invariant).
indices[idx as usize] += 1;
for j in idx as usize + 1..r {
indices[j] = indices[j - 1] + 1;
}
}
Ok(res.into_ref(vm).into_object())