forked from Rust-related/RustPython
Fix a bunch of random tests (#5533)
This commit is contained in:
@@ -6,73 +6,37 @@ pub(crate) use _random::make_module;
|
||||
mod _random {
|
||||
use crate::common::lock::PyMutex;
|
||||
use crate::vm::{
|
||||
builtins::{PyInt, PyTypeRef},
|
||||
builtins::{PyInt, PyTupleRef},
|
||||
convert::ToPyException,
|
||||
function::OptionalOption,
|
||||
types::Constructor,
|
||||
PyObjectRef, PyPayload, PyResult, VirtualMachine,
|
||||
types::{Constructor, Initializer},
|
||||
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use malachite_bigint::{BigInt, BigUint, Sign};
|
||||
use mt19937::MT19937;
|
||||
use num_traits::{Signed, Zero};
|
||||
use rand::{rngs::StdRng, RngCore, SeedableRng};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum PyRng {
|
||||
Std(Box<StdRng>),
|
||||
MT(Box<mt19937::MT19937>),
|
||||
}
|
||||
|
||||
impl Default for PyRng {
|
||||
fn default() -> Self {
|
||||
PyRng::Std(Box::new(StdRng::from_os_rng()))
|
||||
}
|
||||
}
|
||||
|
||||
impl RngCore for PyRng {
|
||||
fn next_u32(&mut self) -> u32 {
|
||||
match self {
|
||||
Self::Std(s) => s.next_u32(),
|
||||
Self::MT(m) => m.next_u32(),
|
||||
}
|
||||
}
|
||||
fn next_u64(&mut self) -> u64 {
|
||||
match self {
|
||||
Self::Std(s) => s.next_u64(),
|
||||
Self::MT(m) => m.next_u64(),
|
||||
}
|
||||
}
|
||||
fn fill_bytes(&mut self, dest: &mut [u8]) {
|
||||
match self {
|
||||
Self::Std(s) => s.fill_bytes(dest),
|
||||
Self::MT(m) => m.fill_bytes(dest),
|
||||
}
|
||||
}
|
||||
}
|
||||
use rand::{RngCore, SeedableRng};
|
||||
use rustpython_vm::types::DefaultConstructor;
|
||||
|
||||
#[pyattr]
|
||||
#[pyclass(name = "Random")]
|
||||
#[derive(Debug, PyPayload)]
|
||||
#[derive(Debug, PyPayload, Default)]
|
||||
struct PyRandom {
|
||||
rng: PyMutex<PyRng>,
|
||||
rng: PyMutex<MT19937>,
|
||||
}
|
||||
|
||||
impl Constructor for PyRandom {
|
||||
type Args = OptionalOption<PyObjectRef>;
|
||||
impl DefaultConstructor for PyRandom {}
|
||||
|
||||
fn py_new(
|
||||
cls: PyTypeRef,
|
||||
// TODO: use x as the seed.
|
||||
_x: Self::Args,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult {
|
||||
PyRandom {
|
||||
rng: PyMutex::default(),
|
||||
}
|
||||
.into_ref_with_type(vm, cls)
|
||||
.map(Into::into)
|
||||
impl Initializer for PyRandom {
|
||||
type Args = OptionalOption;
|
||||
|
||||
fn init(zelf: PyRef<Self>, x: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
|
||||
zelf.seed(x, vm)
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(flags(BASETYPE), with(Constructor))]
|
||||
#[pyclass(flags(BASETYPE), with(Constructor, Initializer))]
|
||||
impl PyRandom {
|
||||
#[pymethod]
|
||||
fn random(&self) -> f64 {
|
||||
@@ -82,9 +46,8 @@ mod _random {
|
||||
|
||||
#[pymethod]
|
||||
fn seed(&self, n: OptionalOption<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let new_rng = n
|
||||
.flatten()
|
||||
.map(|n| {
|
||||
*self.rng.lock() = match n.flatten() {
|
||||
Some(n) => {
|
||||
// Fallback to using hash if object isn't Int-like.
|
||||
let (_, mut key) = match n.downcast::<PyInt>() {
|
||||
Ok(n) => n.as_bigint().abs(),
|
||||
@@ -95,27 +58,21 @@ mod _random {
|
||||
key.reverse();
|
||||
}
|
||||
let key = if key.is_empty() { &[0] } else { key.as_slice() };
|
||||
Ok(PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(
|
||||
key,
|
||||
))))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
*self.rng.lock() = new_rng;
|
||||
MT19937::new_with_slice_seed(key)
|
||||
}
|
||||
None => MT19937::try_from_os_rng()
|
||||
.map_err(|e| std::io::Error::from(e).to_pyexception(vm))?,
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn getrandbits(&self, k: isize, vm: &VirtualMachine) -> PyResult<BigInt> {
|
||||
match k {
|
||||
k if k < 0 => {
|
||||
Err(vm.new_value_error("number of bits must be non-negative".to_owned()))
|
||||
}
|
||||
..0 => Err(vm.new_value_error("number of bits must be non-negative".to_owned())),
|
||||
0 => Ok(BigInt::zero()),
|
||||
_ => {
|
||||
mut k => {
|
||||
let mut rng = self.rng.lock();
|
||||
let mut k = k;
|
||||
let mut gen_u32 = |k| {
|
||||
let r = rng.next_u32();
|
||||
if k < 32 {
|
||||
@@ -145,5 +102,40 @@ mod _random {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn getstate(&self, vm: &VirtualMachine) -> PyTupleRef {
|
||||
let rng = self.rng.lock();
|
||||
vm.new_tuple(
|
||||
rng.get_state()
|
||||
.iter()
|
||||
.copied()
|
||||
.chain([rng.get_index() as u32])
|
||||
.map(|i| vm.ctx.new_int(i).into())
|
||||
.collect::<Vec<PyObjectRef>>(),
|
||||
)
|
||||
}
|
||||
|
||||
#[pymethod]
|
||||
fn setstate(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
|
||||
let state: &[_; mt19937::N + 1] = state
|
||||
.as_slice()
|
||||
.try_into()
|
||||
.map_err(|_| vm.new_value_error("state vector is the wrong size".to_owned()))?;
|
||||
let (index, state) = state.split_last().unwrap();
|
||||
let index: usize = index.try_to_value(vm)?;
|
||||
if index > mt19937::N {
|
||||
return Err(vm.new_value_error("invalid state".to_owned()));
|
||||
}
|
||||
let state: [u32; mt19937::N] = state
|
||||
.iter()
|
||||
.map(|i| i.try_to_value(vm))
|
||||
.process_results(|it| it.collect_array())?
|
||||
.unwrap();
|
||||
let mut rng = self.rng.lock();
|
||||
rng.set_state(&state);
|
||||
rng.set_index(index);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user