mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Fix itertools.count to take PyNumber instead of PyInt (#3822)
This commit is contained in:
@@ -11,14 +11,13 @@ mod decl {
|
||||
convert::ToPyObject,
|
||||
function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs},
|
||||
identifier,
|
||||
protocol::{PyIter, PyIterReturn},
|
||||
protocol::{PyIter, PyIterReturn, PyNumber},
|
||||
stdlib::sys,
|
||||
types::{Constructor, IterNext, IterNextIterable},
|
||||
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, VirtualMachine,
|
||||
};
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use num_bigint::BigInt;
|
||||
use num_traits::{One, Signed, ToPrimitive, Zero};
|
||||
use num_traits::{Signed, ToPrimitive};
|
||||
use std::fmt;
|
||||
|
||||
#[pyattr]
|
||||
@@ -174,14 +173,14 @@ mod decl {
|
||||
#[pyclass(name = "count")]
|
||||
#[derive(Debug, PyPayload)]
|
||||
struct PyItertoolsCount {
|
||||
cur: PyRwLock<BigInt>,
|
||||
step: BigInt,
|
||||
cur: PyRwLock<PyObjectRef>,
|
||||
step: PyIntRef,
|
||||
}
|
||||
|
||||
#[derive(FromArgs)]
|
||||
struct CountNewArgs {
|
||||
#[pyarg(positional, optional)]
|
||||
start: OptionalArg<PyIntRef>,
|
||||
start: OptionalArg<PyObjectRef>,
|
||||
|
||||
#[pyarg(positional, optional)]
|
||||
step: OptionalArg<PyIntRef>,
|
||||
@@ -195,14 +194,11 @@ mod decl {
|
||||
Self::Args { start, step }: Self::Args,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
let start = match start.into_option() {
|
||||
Some(int) => int.as_bigint().clone(),
|
||||
None => BigInt::zero(),
|
||||
};
|
||||
let step = match step.into_option() {
|
||||
Some(int) => int.as_bigint().clone(),
|
||||
None => BigInt::one(),
|
||||
};
|
||||
let start: PyObjectRef = start.into_option().unwrap_or_else(|| vm.new_pyobj(0));
|
||||
let step: PyIntRef = step.into_option().unwrap_or_else(|| vm.new_pyref(1));
|
||||
if !PyNumber::check(&start, vm) {
|
||||
return Err(vm.new_value_error("a number is require".to_owned()));
|
||||
}
|
||||
|
||||
PyItertoolsCount {
|
||||
cur: PyRwLock::new(start),
|
||||
@@ -219,7 +215,7 @@ mod decl {
|
||||
// if (lz->cnt == PY_SSIZE_T_MAX)
|
||||
// return Py_BuildValue("0(00)", Py_TYPE(lz), lz->long_cnt, lz->long_step);
|
||||
#[pymethod(magic)]
|
||||
fn reduce(zelf: PyRef<Self>) -> (PyTypeRef, (BigInt,)) {
|
||||
fn reduce(zelf: PyRef<Self>) -> (PyTypeRef, (PyObjectRef,)) {
|
||||
(zelf.class().clone(), (zelf.cur.read().clone(),))
|
||||
}
|
||||
|
||||
@@ -234,8 +230,9 @@ mod decl {
|
||||
impl IterNext for PyItertoolsCount {
|
||||
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
|
||||
let mut cur = zelf.cur.write();
|
||||
let step = zelf.step.clone();
|
||||
let result = cur.clone();
|
||||
*cur += &zelf.step;
|
||||
*cur = vm._iadd(&*cur, step.as_object())?;
|
||||
Ok(PyIterReturn::Return(result.to_pyobject(vm)))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user