diff --git a/tests/snippets/stdlib_itertools.py b/tests/snippets/stdlib_itertools.py index 117ccb938..4249814b6 100644 --- a/tests/snippets/stdlib_itertools.py +++ b/tests/snippets/stdlib_itertools.py @@ -301,7 +301,6 @@ assert list(t[0]) == [1,2,3] assert list(t[0]) == [] # itertools.product - it = itertools.product([1, 2], [3, 4]) assert (1, 3) == next(it) assert (1, 4) == next(it) @@ -324,10 +323,12 @@ with assert_raises(TypeError): itertools.product([1, 2], repeat=None) # itertools.combinations - it = itertools.combinations([1, 2, 3, 4], 2) assert list(it) == [(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)] +it = itertools.combinations([1, 2, 3], 0) +assert list(it) == [()] + it = itertools.combinations([1, 2, 3], 1) assert list(it) == [(1,), (2,), (3,)] @@ -345,6 +346,9 @@ with assert_raises(StopIteration): with assert_raises(ValueError): itertools.combinations([1, 2, 3, 4], -2) +with assert_raises(TypeError): + itertools.combinations([1, 2, 3, 4], None) + # itertools.zip_longest tests zl = itertools.zip_longest assert list(zl(['a', 'b', 'c'], range(3), [9, 8, 7])) \ diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 9ac099ff0..ae7fa12a1 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -61,7 +61,7 @@ Basically reference counting, but then done by rust. /// to the python object by 1. pub type PyObjectRef = Rc>; -/// Use this type for function which return a python object or and exception. +/// Use this type for functions which return a python object or an exception. /// Both the python object and the python exception are `PyObjectRef` types /// since exceptions are also python objects. pub type PyResult = Result; // A valid value, or an exception diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index f114d8076..cc7ada2b5 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -314,7 +314,7 @@ impl PyItertoolsTakewhile { return Err(new_stop_iteration(vm)); } - // might be StopIteration or anything else, which is propaged upwwards + // might be StopIteration or anything else, which is propagated upwards let obj = call_next(vm, &self.iterable)?; let predicate = &self.predicate; @@ -908,45 +908,40 @@ impl PyItertoolsCombinations { let n = self.pool.len(); let r = self.r.get(); + if r == 0 { + self.exhausted.set(true); + return Ok(vm.ctx.new_tuple(vec![])); + } + let res = PyTuple::from( - self.pool + self.indices + .borrow() .iter() - .enumerate() - .filter(|(idx, _)| self.indices.borrow().contains(&idx)) - .map(|(_, num)| num.clone()) + .map(|&i| self.pool[i].clone()) .collect::>(), ); let mut indices = self.indices.borrow_mut(); - let mut sentinel = false; // Scan indices right-to-left until finding one that is not at its maximum (i + n - r). - let mut idx = r - 1; - loop { - if indices[idx] != idx + n - r { - sentinel = true; - break; - } - - if idx != 0 { - idx -= 1; - } else { - break; - } + let mut idx = r as isize - 1; + while idx >= 0 && indices[idx as usize] == idx as usize + n - r { + idx -= 1; } + // If no suitable index is found, then the indices are all at // their maximum value and we're done. - if !sentinel { + if idx < 0 { self.exhausted.set(true); - } - - // Increment the current index which we know is not at its - // maximum. Then move back to the right setting each index - // to its lowest possible value (one higher than the index - // to its left -- this maintains the sort order invariant). - indices[idx] += 1; - for j in idx + 1..r { - indices[j] = indices[j - 1] + 1; + } else { + // Increment the current index which we know is not at its + // maximum. Then move back to the right setting each index + // to its lowest possible value (one higher than the index + // to its left -- this maintains the sort order invariant). + indices[idx as usize] += 1; + for j in idx as usize + 1..r { + indices[j] = indices[j - 1] + 1; + } } Ok(res.into_ref(vm).into_object())