Fix a bunch of random tests (#5533)

This commit is contained in:
Noa
2025-02-24 17:41:54 -06:00
committed by GitHub
parent b55a55afc7
commit 2721f2de3f
4 changed files with 63 additions and 98 deletions

View File

@@ -6,73 +6,37 @@ pub(crate) use _random::make_module;
mod _random {
use crate::common::lock::PyMutex;
use crate::vm::{
builtins::{PyInt, PyTypeRef},
builtins::{PyInt, PyTupleRef},
convert::ToPyException,
function::OptionalOption,
types::Constructor,
PyObjectRef, PyPayload, PyResult, VirtualMachine,
types::{Constructor, Initializer},
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
};
use itertools::Itertools;
use malachite_bigint::{BigInt, BigUint, Sign};
use mt19937::MT19937;
use num_traits::{Signed, Zero};
use rand::{rngs::StdRng, RngCore, SeedableRng};
#[derive(Debug)]
enum PyRng {
Std(Box<StdRng>),
MT(Box<mt19937::MT19937>),
}
impl Default for PyRng {
fn default() -> Self {
PyRng::Std(Box::new(StdRng::from_os_rng()))
}
}
impl RngCore for PyRng {
fn next_u32(&mut self) -> u32 {
match self {
Self::Std(s) => s.next_u32(),
Self::MT(m) => m.next_u32(),
}
}
fn next_u64(&mut self) -> u64 {
match self {
Self::Std(s) => s.next_u64(),
Self::MT(m) => m.next_u64(),
}
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
match self {
Self::Std(s) => s.fill_bytes(dest),
Self::MT(m) => m.fill_bytes(dest),
}
}
}
use rand::{RngCore, SeedableRng};
use rustpython_vm::types::DefaultConstructor;
#[pyattr]
#[pyclass(name = "Random")]
#[derive(Debug, PyPayload)]
#[derive(Debug, PyPayload, Default)]
struct PyRandom {
rng: PyMutex<PyRng>,
rng: PyMutex<MT19937>,
}
impl Constructor for PyRandom {
type Args = OptionalOption<PyObjectRef>;
impl DefaultConstructor for PyRandom {}
fn py_new(
cls: PyTypeRef,
// TODO: use x as the seed.
_x: Self::Args,
vm: &VirtualMachine,
) -> PyResult {
PyRandom {
rng: PyMutex::default(),
}
.into_ref_with_type(vm, cls)
.map(Into::into)
impl Initializer for PyRandom {
type Args = OptionalOption;
fn init(zelf: PyRef<Self>, x: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
zelf.seed(x, vm)
}
}
#[pyclass(flags(BASETYPE), with(Constructor))]
#[pyclass(flags(BASETYPE), with(Constructor, Initializer))]
impl PyRandom {
#[pymethod]
fn random(&self) -> f64 {
@@ -82,9 +46,8 @@ mod _random {
#[pymethod]
fn seed(&self, n: OptionalOption<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
let new_rng = n
.flatten()
.map(|n| {
*self.rng.lock() = match n.flatten() {
Some(n) => {
// Fallback to using hash if object isn't Int-like.
let (_, mut key) = match n.downcast::<PyInt>() {
Ok(n) => n.as_bigint().abs(),
@@ -95,27 +58,21 @@ mod _random {
key.reverse();
}
let key = if key.is_empty() { &[0] } else { key.as_slice() };
Ok(PyRng::MT(Box::new(mt19937::MT19937::new_with_slice_seed(
key,
))))
})
.transpose()?
.unwrap_or_default();
*self.rng.lock() = new_rng;
MT19937::new_with_slice_seed(key)
}
None => MT19937::try_from_os_rng()
.map_err(|e| std::io::Error::from(e).to_pyexception(vm))?,
};
Ok(())
}
#[pymethod]
fn getrandbits(&self, k: isize, vm: &VirtualMachine) -> PyResult<BigInt> {
match k {
k if k < 0 => {
Err(vm.new_value_error("number of bits must be non-negative".to_owned()))
}
..0 => Err(vm.new_value_error("number of bits must be non-negative".to_owned())),
0 => Ok(BigInt::zero()),
_ => {
mut k => {
let mut rng = self.rng.lock();
let mut k = k;
let mut gen_u32 = |k| {
let r = rng.next_u32();
if k < 32 {
@@ -145,5 +102,40 @@ mod _random {
}
}
}
#[pymethod]
fn getstate(&self, vm: &VirtualMachine) -> PyTupleRef {
let rng = self.rng.lock();
vm.new_tuple(
rng.get_state()
.iter()
.copied()
.chain([rng.get_index() as u32])
.map(|i| vm.ctx.new_int(i).into())
.collect::<Vec<PyObjectRef>>(),
)
}
#[pymethod]
fn setstate(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
let state: &[_; mt19937::N + 1] = state
.as_slice()
.try_into()
.map_err(|_| vm.new_value_error("state vector is the wrong size".to_owned()))?;
let (index, state) = state.split_last().unwrap();
let index: usize = index.try_to_value(vm)?;
if index > mt19937::N {
return Err(vm.new_value_error("invalid state".to_owned()));
}
let state: [u32; mt19937::N] = state
.iter()
.map(|i| i.try_to_value(vm))
.process_results(|it| it.collect_array())?
.unwrap();
let mut rng = self.rng.lock();
rng.set_state(&state);
rng.set_index(index);
Ok(())
}
}
}