Fix itertools.count to take PyNumber instead of PyInt (#3822)

This commit is contained in:
oow214
2022-06-27 23:02:45 +09:00
committed by GitHub
parent 174c026727
commit 6ad0e547f5

View File

@@ -11,14 +11,13 @@ mod decl {
convert::ToPyObject,
function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs},
identifier,
protocol::{PyIter, PyIterReturn},
protocol::{PyIter, PyIterReturn, PyNumber},
stdlib::sys,
types::{Constructor, IterNext, IterNextIterable},
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, VirtualMachine,
};
use crossbeam_utils::atomic::AtomicCell;
use num_bigint::BigInt;
use num_traits::{One, Signed, ToPrimitive, Zero};
use num_traits::{Signed, ToPrimitive};
use std::fmt;
#[pyattr]
@@ -174,14 +173,14 @@ mod decl {
#[pyclass(name = "count")]
#[derive(Debug, PyPayload)]
struct PyItertoolsCount {
cur: PyRwLock<BigInt>,
step: BigInt,
cur: PyRwLock<PyObjectRef>,
step: PyIntRef,
}
#[derive(FromArgs)]
struct CountNewArgs {
#[pyarg(positional, optional)]
start: OptionalArg<PyIntRef>,
start: OptionalArg<PyObjectRef>,
#[pyarg(positional, optional)]
step: OptionalArg<PyIntRef>,
@@ -195,14 +194,11 @@ mod decl {
Self::Args { start, step }: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
let start = match start.into_option() {
Some(int) => int.as_bigint().clone(),
None => BigInt::zero(),
};
let step = match step.into_option() {
Some(int) => int.as_bigint().clone(),
None => BigInt::one(),
};
let start: PyObjectRef = start.into_option().unwrap_or_else(|| vm.new_pyobj(0));
let step: PyIntRef = step.into_option().unwrap_or_else(|| vm.new_pyref(1));
if !PyNumber::check(&start, vm) {
return Err(vm.new_value_error("a number is require".to_owned()));
}
PyItertoolsCount {
cur: PyRwLock::new(start),
@@ -219,7 +215,7 @@ mod decl {
// if (lz->cnt == PY_SSIZE_T_MAX)
// return Py_BuildValue("0(00)", Py_TYPE(lz), lz->long_cnt, lz->long_step);
#[pymethod(magic)]
fn reduce(zelf: PyRef<Self>) -> (PyTypeRef, (BigInt,)) {
fn reduce(zelf: PyRef<Self>) -> (PyTypeRef, (PyObjectRef,)) {
(zelf.class().clone(), (zelf.cur.read().clone(),))
}
@@ -234,8 +230,9 @@ mod decl {
impl IterNext for PyItertoolsCount {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
let mut cur = zelf.cur.write();
let step = zelf.step.clone();
let result = cur.clone();
*cur += &zelf.step;
*cur = vm._iadd(&*cur, step.as_object())?;
Ok(PyIterReturn::Return(result.to_pyobject(vm)))
}
}