diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 05ab7bdaa..3c55688da 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -1449,8 +1449,11 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase): self.assertEqual(b, b1) self.assertIs(b, b1) - # TODO: RUSTPYTHON - @unittest.expectedFailure + # NOTE: RUSTPYTHON: + # + # The second instance of self.assertGreater was replaced with + # self.assertGreaterEqual since, in RustPython, the underlying storage + # is a Vec which doesn't require trailing null byte. def test_alloc(self): b = bytearray() alloc = b.__alloc__() @@ -1459,12 +1462,15 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase): for i in range(100): b += b"x" alloc = b.__alloc__() - self.assertGreater(alloc, len(b)) # including trailing null byte + self.assertGreaterEqual(alloc, len(b)) # NOTE: RUSTPYTHON patched if alloc not in seq: seq.append(alloc) - # TODO: RUSTPYTHON - @unittest.expectedFailure + # NOTE: RUSTPYTHON: + # + # The usages of self.assertGreater were replaced with + # self.assertGreaterEqual since, in RustPython, the underlying storage + # is a Vec which doesn't require trailing null byte. def test_init_alloc(self): b = bytearray() def g(): @@ -1475,12 +1481,12 @@ class ByteArrayTest(BaseBytesTest, unittest.TestCase): self.assertEqual(len(b), len(a)) self.assertLessEqual(len(b), i) alloc = b.__alloc__() - self.assertGreater(alloc, len(b)) # including trailing null byte + self.assertGreaterEqual(alloc, len(b)) # NOTE: RUSTPYTHON patched b.__init__(g()) self.assertEqual(list(b), list(range(1, 100))) self.assertEqual(len(b), 99) alloc = b.__alloc__() - self.assertGreater(alloc, len(b)) + self.assertGreaterEqual(alloc, len(b)) # NOTE: RUSTPYTHON patched def test_extend(self): orig = b'hello' @@ -1999,8 +2005,6 @@ class ByteArraySubclassTest(SubclassTest, unittest.TestCase): basetype = bytearray type2test = ByteArraySubclass - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_init_override(self): class subclass(bytearray): def __init__(me, newarg=1, *args, **kwargs): diff --git a/vm/src/builtins/bytearray.rs b/vm/src/builtins/bytearray.rs index 905c14304..6c5baaa38 100644 --- a/vm/src/builtins/bytearray.rs +++ b/vm/src/builtins/bytearray.rs @@ -18,7 +18,7 @@ use crate::common::lock::{ PyMappedRwLockReadGuard, PyMappedRwLockWriteGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, }; -use crate::function::{OptionalArg, OptionalOption}; +use crate::function::{FuncArgs, OptionalArg, OptionalOption}; use crate::sliceable::{PySliceableSequence, PySliceableSequenceMut, SequenceIndex}; use crate::slots::{ BufferProtocol, Comparable, Hashable, Iterable, PyComparisonOp, PyIter, Unhashable, @@ -45,7 +45,7 @@ use std::mem::size_of; /// - any object implementing the buffer API.\n \ /// - an integer"; #[pyclass(module = false, name = "bytearray")] -#[derive(Debug)] +#[derive(Debug, Default)] pub struct PyByteArray { inner: PyRwLock, exports: AtomicCell, @@ -102,12 +102,16 @@ pub(crate) fn init(context: &PyContext) { #[pyimpl(flags(BASETYPE), with(Hashable, Comparable, BufferProtocol, Iterable))] impl PyByteArray { #[pyslot] - fn tp_new( - cls: PyTypeRef, - options: ByteInnerNewOptions, - vm: &VirtualMachine, - ) -> PyResult> { - options.get_bytearray(cls, vm) + fn tp_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult> { + PyByteArray::default().into_ref_with_type(vm, cls) + } + + #[pymethod(magic)] + fn init(&self, options: ByteInnerNewOptions, vm: &VirtualMachine) -> PyResult<()> { + // First unpack bytearray and *then* get a lock to set it. + let mut inner = options.get_bytearray_inner(vm)?; + std::mem::swap(&mut *self.inner_mut(), &mut inner); + Ok(()) } #[inline] @@ -124,6 +128,11 @@ impl PyByteArray { self.inner().repr("bytearray(", ")") } + #[pymethod(magic)] + fn alloc(&self) -> usize { + self.inner().capacity() + } + #[pymethod(name = "__len__")] fn len(&self) -> usize { self.borrow_buf().len() diff --git a/vm/src/bytesinner.rs b/vm/src/bytesinner.rs index 875cd188a..e0904ff8a 100644 --- a/vm/src/bytesinner.rs +++ b/vm/src/bytesinner.rs @@ -4,7 +4,7 @@ use num_bigint::BigInt; use num_traits::ToPrimitive; use crate::anystr::{self, AnyStr, AnyStrContainer, AnyStrWrapper}; -use crate::builtins::bytearray::{PyByteArray, PyByteArrayRef}; +use crate::builtins::bytearray::PyByteArray; use crate::builtins::bytes::{PyBytes, PyBytesRef}; use crate::builtins::int::{PyInt, PyIntRef}; use crate::builtins::pystr::{self, PyStr, PyStrRef}; @@ -120,12 +120,8 @@ impl ByteInnerNewOptions { PyBytes::from(inner).into_ref_with_type(vm, cls) } - pub fn get_bytearray( - mut self, - cls: PyTypeRef, - vm: &VirtualMachine, - ) -> PyResult { - let inner = if let OptionalArg::Present(source) = self.source.take() { + pub fn get_bytearray_inner(mut self, vm: &VirtualMachine) -> PyResult { + if let OptionalArg::Present(source) = self.source.take() { match_class!(match source { s @ PyStr => Self::get_value_from_string(s, self.encoding, self.errors, vm), i @ PyInt => { @@ -139,9 +135,7 @@ impl ByteInnerNewOptions { }) } else { self.check_args(vm).map(|_| vec![].into()) - }?; - - PyByteArray::from(inner).into_ref_with_type(vm, cls) + } } } @@ -293,6 +287,11 @@ impl PyBytesInner { self.elements.len() } + #[inline] + pub fn capacity(&self) -> usize { + self.elements.capacity() + } + #[inline] pub fn is_empty(&self) -> bool { self.elements.is_empty()