diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py index 35412f5be..95a0100c0 100644 --- a/Lib/test/test_marshal.py +++ b/Lib/test/test_marshal.py @@ -101,8 +101,6 @@ class StringTestCase(unittest.TestCase, HelperMixin): for s in ["", "Andr\xe8 Previn", "abc", " "*10000]: self.helper(s) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bytes(self): for s in [b"", b"Andr\xe8 Previn", b"abc", b" "*10000]: self.helper(s) @@ -337,8 +335,6 @@ class BugsTestCase(unittest.TestCase): self.assertRaises(ValueError, marshal.load, BadReader(marshal.dumps(value))) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_eof(self): data = marshal.dumps(("hello", "dolly", None)) for i in range(len(data)): diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index ae1cfa27b..957c95c40 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -14,19 +14,17 @@ mod decl { protocol::PyBuffer, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; - /// TODO - /// PyBytes: Currently getting recursion error with match_class! use num_bigint::{BigInt, Sign}; use num_traits::Zero; #[repr(u8)] enum Type { // Null = b'0', - // None = b'N', + None = b'N', False = b'F', True = b'T', // StopIter = b'S', - // Ellipsis = b'.', + Ellipsis = b'.', Int = b'i', Float = b'g', // Complex = b'y', @@ -56,11 +54,11 @@ mod decl { use Type::*; Ok(match value { // b'0' => Null, - // b'N' => None, + b'N' => None, b'F' => False, b'T' => True, // b'S' => StopIter, - // b'.' => Ellipsis, + b'.' => Ellipsis, b'i' => Int, b'g' => Float, // b'y' => Complex, @@ -109,80 +107,92 @@ mod decl { /// Dumping helper function to turn a value into bytes. fn dump_obj(buf: &mut Vec, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - match_class!(match value { - pyint @ PyInt => { - if pyint.class().is(vm.ctx.types.bool_type) { - let typ = if pyint.as_bigint().is_zero() { - Type::False + if vm.is_none(&value) { + buf.push(Type::None as u8); + } else if value.is(&vm.ctx.ellipsis) { + buf.push(Type::Ellipsis as u8); + } else { + match_class!(match value { + pyint @ PyInt => { + if pyint.class().is(vm.ctx.types.bool_type) { + let typ = if pyint.as_bigint().is_zero() { + Type::False + } else { + Type::True + }; + buf.push(typ as u8); } else { - Type::True - }; - buf.push(typ as u8); - } else { - buf.push(Type::Int as u8); - let (sign, int_bytes) = pyint.as_bigint().to_bytes_le(); - let mut len = int_bytes.len() as i32; - if sign == Sign::Minus { - len = -len; + buf.push(Type::Int as u8); + let (sign, int_bytes) = pyint.as_bigint().to_bytes_le(); + let mut len = int_bytes.len() as i32; + if sign == Sign::Minus { + len = -len; + } + buf.extend(len.to_le_bytes()); + buf.extend(int_bytes); } - buf.extend(len.to_le_bytes()); - buf.extend(int_bytes); } - } - pyfloat @ PyFloat => { - buf.push(Type::Float as u8); - buf.extend(pyfloat.to_f64().to_le_bytes()); - } - pystr @ PyStr => { - buf.push(Type::Str as u8); - write_size(buf, pystr.as_str().len(), vm)?; - buf.extend(pystr.as_str().as_bytes()); - } - pylist @ PyList => { - buf.push(Type::List as u8); - let pylist_items = pylist.borrow_vec(); - dump_seq(buf, pylist_items.iter(), vm)?; - } - pyset @ PySet => { - buf.push(Type::Set as u8); - let elements = pyset.elements(); - dump_seq(buf, elements.iter(), vm)?; - } - pyfrozen @ PyFrozenSet => { - buf.push(Type::FrozenSet as u8); - let elements = pyfrozen.elements(); - dump_seq(buf, elements.iter(), vm)?; - } - pytuple @ PyTuple => { - buf.push(Type::Tuple as u8); - dump_seq(buf, pytuple.iter(), vm)?; - } - pydict @ PyDict => { - buf.push(Type::Dict as u8); - write_size(buf, pydict.len(), vm)?; - for (key, value) in pydict { - dump_obj(buf, key, vm)?; - dump_obj(buf, value, vm)?; + pyfloat @ PyFloat => { + buf.push(Type::Float as u8); + buf.extend(pyfloat.to_f64().to_le_bytes()); } - } - bytes @ PyByteArray => { - buf.push(Type::Bytes as u8); - let data = bytes.borrow_buf(); - write_size(buf, data.len(), vm)?; - buf.extend(&*data); - } - co @ PyCode => { - buf.push(Type::Code as u8); - let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); - write_size(buf, bytes.len(), vm)?; - buf.extend(bytes); - } - _ => { - return Err(vm.new_not_implemented_error( - "TODO: not implemented yet or marshal unsupported type".to_owned(), - )); - } - }); + pystr @ PyStr => { + buf.push(Type::Str as u8); + write_size(buf, pystr.as_str().len(), vm)?; + buf.extend(pystr.as_str().as_bytes()); + } + pylist @ PyList => { + buf.push(Type::List as u8); + let pylist_items = pylist.borrow_vec(); + dump_seq(buf, pylist_items.iter(), vm)?; + } + pyset @ PySet => { + buf.push(Type::Set as u8); + let elements = pyset.elements(); + dump_seq(buf, elements.iter(), vm)?; + } + pyfrozen @ PyFrozenSet => { + buf.push(Type::FrozenSet as u8); + let elements = pyfrozen.elements(); + dump_seq(buf, elements.iter(), vm)?; + } + pytuple @ PyTuple => { + buf.push(Type::Tuple as u8); + dump_seq(buf, pytuple.iter(), vm)?; + } + pydict @ PyDict => { + buf.push(Type::Dict as u8); + write_size(buf, pydict.len(), vm)?; + for (key, value) in pydict { + dump_obj(buf, key, vm)?; + dump_obj(buf, value, vm)?; + } + } + bytes @ PyBytes => { + buf.push(Type::Bytes as u8); + let data = bytes.as_bytes(); + write_size(buf, data.len(), vm)?; + buf.extend(&*data); + } + bytes @ PyByteArray => { + buf.push(Type::Bytes as u8); + let data = bytes.borrow_buf(); + write_size(buf, data.len(), vm)?; + buf.extend(&*data); + } + co @ PyCode => { + buf.push(Type::Code as u8); + let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); + write_size(buf, bytes.len(), vm)?; + buf.extend(bytes); + } + _ => { + return Err(vm.new_not_implemented_error( + "TODO: not implemented yet or marshal unsupported type".to_owned(), + )); + } + }) + } Ok(()) } @@ -248,8 +258,10 @@ mod decl { let typ = Type::try_from(*type_indicator) .map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?; let (obj, buf) = match typ { - Type::True => ((true).to_pyobject(vm), buf), - Type::False => ((false).to_pyobject(vm), buf), + Type::True => (true.to_pyobject(vm), buf), + Type::False => (false.to_pyobject(vm), buf), + Type::None => (vm.ctx.none(), buf), + Type::Ellipsis => (vm.ctx.ellipsis(), buf), Type::Int => { if buf.len() < 4 { return Err(too_short_error(vm));