From 320ed26ab957afc908cb63a7296d2c06df68a636 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 6 Aug 2022 06:18:22 +0900 Subject: [PATCH 1/3] marshal _dumps to take &mut vec --- vm/src/stdlib/marshal.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index 88c39f012..20739bd8c 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -18,8 +18,7 @@ mod decl { /// PyBytes: Currently getting recursion error with match_class! use ascii::AsciiStr; use num_bigint::{BigInt, Sign}; - use std::ops::Deref; - use std::slice::Iter; + use std::{ops::Deref, slice::Iter}; const STR_BYTE: u8 = b's'; const INT_BYTE: u8 = b'i'; @@ -49,7 +48,8 @@ mod decl { let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec(); // For each element, dump into binary, then add its length and value. for element in pyobjs { - let element_bytes: Vec = _dumps(element.clone(), vm)?; + let mut element_bytes = Vec::new(); + _dumps(&mut element_bytes, element.clone(), vm)?; byte_list.extend(size_to_bytes(element_bytes.len(), vm)?); byte_list.extend(element_bytes) } @@ -57,7 +57,7 @@ mod decl { } /// Dumping helper function to turn a value into bytes. - fn _dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { + fn _dumps(buf: &mut Vec, value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { let r = match_class!(match value { pyint @ PyInt => { if pyint.class().is(vm.ctx.types.bool_type) { @@ -138,12 +138,15 @@ mod decl { )); } }); + buf.extend(r.as_slice()); Ok(r) } #[pyfunction] fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { - Ok(PyBytes::from(_dumps(value, vm)?)) + let mut buf = Vec::new(); + _dumps(&mut buf, value, vm)?; + Ok(PyBytes::from(buf)) } #[pyfunction] From a353d5e27ad832cd9061c79b304d7cc581d8308d Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 7 Aug 2022 02:22:53 +0900 Subject: [PATCH 2/3] redesign marshal --- Lib/test/test_marshal.py | 12 -- vm/src/builtins/dict.rs | 4 - vm/src/dictdatatype.rs | 16 -- vm/src/stdlib/marshal.rs | 347 ++++++++++++++++++--------------------- 4 files changed, 163 insertions(+), 216 deletions(-) diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py index 3a5f74041..35412f5be 100644 --- a/Lib/test/test_marshal.py +++ b/Lib/test/test_marshal.py @@ -93,14 +93,10 @@ class FloatTestCase(unittest.TestCase, HelperMixin): n *= 123.4567 class StringTestCase(unittest.TestCase, HelperMixin): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_unicode(self): for s in ["", "Andr\xe8 Previn", "abc", " "*10000]: self.helper(marshal.loads(marshal.dumps(s))) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_string(self): for s in ["", "Andr\xe8 Previn", "abc", " "*10000]: self.helper(s) @@ -159,13 +155,9 @@ class ContainerTestCase(unittest.TestCase, HelperMixin): 'aunicode': "Andr\xe8 Previn" } - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_dict(self): self.helper(self.d) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_list(self): self.helper(list(self.d.items())) @@ -178,8 +170,6 @@ class ContainerTestCase(unittest.TestCase, HelperMixin): class BufferTestCase(unittest.TestCase, HelperMixin): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bytearray(self): b = bytearray(b"abc") self.helper(b) @@ -298,8 +288,6 @@ class BugsTestCase(unittest.TestCase): testString = 'abc' * size marshal.dumps(testString) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_invalid_longs(self): # Issue #7019: marshal.loads shouldn't produce unnormalized PyLongs invalid_string = b'l\x02\x00\x00\x00\x00\x00\x00\x00' diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index 93a4c91f0..9d7f2f282 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -70,10 +70,6 @@ impl PyDict { &self.entries } - pub(crate) fn from_entries(entries: DictContentType) -> Self { - Self { entries } - } - // Used in update and ior. fn merge_object( dict: &DictContentType, diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index d7a5c054d..6dd7e3dd3 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -99,22 +99,6 @@ struct DictEntry { } static_assertions::assert_eq_size!(DictEntry, Option>); -impl DictEntry { - pub(crate) fn as_tuple(&self) -> (PyObjectRef, T) { - (self.key.clone(), self.value.clone()) - } -} - -impl Dict { - pub(crate) fn as_kvpairs(&self) -> Vec<(PyObjectRef, T)> { - let entries = &self.inner.read().entries; - entries - .iter() - .filter_map(|entry| entry.as_ref().map(|dict_entry| dict_entry.as_tuple())) - .collect() - } -} - #[derive(Debug, PartialEq)] pub struct DictSize { indices_size: usize, diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index 20739bd8c..9615eec95 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -4,8 +4,8 @@ pub(crate) use decl::make_module; mod decl { use crate::{ builtins::{ - dict::DictContentType, PyByteArray, PyBytes, PyCode, PyDict, PyFloat, PyFrozenSet, - PyInt, PyList, PySet, PyStr, PyTuple, + PyBaseExceptionRef, PyByteArray, PyBytes, PyCode, PyDict, PyFloat, PyFrozenSet, PyInt, + PyList, PySet, PyStr, PyTuple, }, bytecode, convert::ToPyObject, @@ -16,121 +16,112 @@ mod decl { }; /// TODO /// PyBytes: Currently getting recursion error with match_class! - use ascii::AsciiStr; use num_bigint::{BigInt, Sign}; - use std::{ops::Deref, slice::Iter}; + use num_traits::Zero; const STR_BYTE: u8 = b's'; const INT_BYTE: u8 = b'i'; const FLOAT_BYTE: u8 = b'f'; - const BOOL_BYTE: u8 = b'b'; + const TRUE_BYTE: u8 = b'T'; + const FALSE_BYTE: u8 = b'F'; const LIST_BYTE: u8 = b'['; const TUPLE_BYTE: u8 = b'('; const DICT_BYTE: u8 = b','; const SET_BYTE: u8 = b'~'; const FROZEN_SET_BYTE: u8 = b'<'; const BYTE_ARRAY: u8 = b'>'; + const TYPE_CODE: u8 = b'c'; - /// Safely convert usize to 4 le bytes - fn size_to_bytes(x: usize, vm: &VirtualMachine) -> PyResult<[u8; 4]> { - // For marshalling we want to convert lengths to bytes. To save space - // we limit the size to u32 to keep marshalling smaller. - match u32::try_from(x) { - Ok(n) => Ok(n.to_le_bytes()), - Err(_) => { - Err(vm.new_value_error("Size exceeds 2^32 capacity for marshaling.".to_owned())) - } - } + fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg( + vm.ctx.exceptions.eof_error.to_owned(), + "marshal data too short".to_owned(), + ) } - /// Dumps a iterator of objects into binary vector. - fn dump_list(pyobjs: Iter, vm: &VirtualMachine) -> PyResult> { - let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec(); + /// Dumps a sequence of objects into binary vector. + fn dump_seq( + buf: &mut Vec, + iter: std::slice::Iter, + vm: &VirtualMachine, + ) -> PyResult<()> { + write_size(buf, iter.len(), vm)?; // For each element, dump into binary, then add its length and value. - for element in pyobjs { - let mut element_bytes = Vec::new(); - _dumps(&mut element_bytes, element.clone(), vm)?; - byte_list.extend(size_to_bytes(element_bytes.len(), vm)?); - byte_list.extend(element_bytes) + for element in iter { + dump_obj(buf, element.clone(), vm)?; } - Ok(byte_list) + Ok(()) } /// Dumping helper function to turn a value into bytes. - fn _dumps(buf: &mut Vec, value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - let r = match_class!(match value { + fn dump_obj(buf: &mut Vec, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + match_class!(match value { pyint @ PyInt => { if pyint.class().is(vm.ctx.types.bool_type) { - let (_, mut bool_bytes) = pyint.as_bigint().to_bytes_le(); - bool_bytes.push(BOOL_BYTE); - bool_bytes - } else { - let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le(); - let sign_byte = match sign { - Sign::Minus => b'-', - Sign::NoSign => b'0', - Sign::Plus => b'+', + let typ = if pyint.as_bigint().is_zero() { + FALSE_BYTE + } else { + TRUE_BYTE }; - // Return as [TYPE, SIGN, uint bytes] - int_bytes.insert(0, sign_byte); - int_bytes.push(INT_BYTE); - int_bytes + buf.push(typ); + } else { + buf.push(INT_BYTE); + let (sign, int_bytes) = pyint.as_bigint().to_bytes_le(); + let mut len = int_bytes.len() as i32; + if sign == Sign::Minus { + len = -len; + } + buf.extend(len.to_le_bytes()); + buf.extend(int_bytes); } } pyfloat @ PyFloat => { - let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec(); - float_bytes.push(FLOAT_BYTE); - float_bytes + buf.push(FLOAT_BYTE); + buf.extend(pyfloat.to_f64().to_le_bytes()); } pystr @ PyStr => { - let mut str_bytes = pystr.as_str().as_bytes().to_vec(); - str_bytes.push(STR_BYTE); - str_bytes + buf.push(STR_BYTE); + write_size(buf, pystr.as_str().len(), vm)?; + buf.extend(pystr.as_str().as_bytes()); } pylist @ PyList => { + buf.push(LIST_BYTE); let pylist_items = pylist.borrow_vec(); - let mut list_bytes = dump_list(pylist_items.iter(), vm)?; - list_bytes.push(LIST_BYTE); - list_bytes + dump_seq(buf, pylist_items.iter(), vm)?; } pyset @ PySet => { + buf.push(SET_BYTE); let elements = pyset.elements(); - let mut set_bytes = dump_list(elements.iter(), vm)?; - set_bytes.push(SET_BYTE); - set_bytes + dump_seq(buf, elements.iter(), vm)?; } pyfrozen @ PyFrozenSet => { + buf.push(FROZEN_SET_BYTE); let elements = pyfrozen.elements(); - let mut fset_bytes = dump_list(elements.iter(), vm)?; - fset_bytes.push(FROZEN_SET_BYTE); - fset_bytes + dump_seq(buf, elements.iter(), vm)?; } pytuple @ PyTuple => { - let mut tuple_bytes = dump_list(pytuple.iter(), vm)?; - tuple_bytes.push(TUPLE_BYTE); - tuple_bytes + buf.push(TUPLE_BYTE); + dump_seq(buf, pytuple.iter(), vm)?; } pydict @ PyDict => { - let key_value_pairs = pydict._as_dict_inner().clone().as_kvpairs(); - // Converts list of tuples to PyObjectRefs of tuples - let elements: Vec = key_value_pairs - .into_iter() - .map(|(k, v)| PyTuple::new_ref(vec![k, v], &vm.ctx).to_pyobject(vm)) - .collect(); - // Converts list of tuples to list, dump into binary - let mut dict_bytes = dump_list(elements.iter(), vm)?; - dict_bytes.push(LIST_BYTE); - dict_bytes.push(DICT_BYTE); - dict_bytes + buf.push(DICT_BYTE); + write_size(buf, pydict.len(), vm)?; + for (key, value) in pydict { + dump_obj(buf, key, vm)?; + dump_obj(buf, value, vm)?; + } } - pybyte_array @ PyByteArray => { - let mut pybytes = pybyte_array.borrow_buf_mut(); - pybytes.push(BYTE_ARRAY); - pybytes.deref().to_owned() + bytes @ PyByteArray => { + buf.push(BYTE_ARRAY); + let data = bytes.borrow_buf(); + write_size(buf, data.len(), vm)?; + buf.extend(&*data); } co @ PyCode => { - // Code is default, doesn't have prefix. - co.code.map_clone_bag(&bytecode::BasicBag).to_bytes() + buf.push(TYPE_CODE); + let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); + write_size(buf, bytes.len(), vm)?; + buf.extend(bytes); } _ => { return Err(vm.new_not_implemented_error( @@ -138,14 +129,13 @@ mod decl { )); } }); - buf.extend(r.as_slice()); - Ok(r) + Ok(()) } #[pyfunction] fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { let mut buf = Vec::new(); - _dumps(&mut buf, value, vm)?; + dump_obj(&mut buf, value, vm)?; Ok(PyBytes::from(buf)) } @@ -156,161 +146,150 @@ mod decl { Ok(()) } + /// Safely convert usize to 4 le bytes + fn write_size(buf: &mut Vec, x: usize, vm: &VirtualMachine) -> PyResult<()> { + // For marshalling we want to convert lengths to bytes. To save space + // we limit the size to u32 to keep marshalling smaller. + let n = u32::try_from(x).map_err(|_| { + vm.new_value_error("Size exceeds 2^32 capacity for marshaling.".to_owned()) + })?; + buf.extend(n.to_le_bytes()); + Ok(()) + } + /// Read the next 4 bytes of a slice, read as u32, pass as usize. /// Returns the rest of buffer with the value. - fn eat_length<'a>(bytes: &'a [u8], vm: &VirtualMachine) -> PyResult<(usize, &'a [u8])> { - let (u32_bytes, rest) = bytes.split_at(4); - let length = u32::from_le_bytes(u32_bytes.try_into().map_err(|_| { - vm.new_value_error("Could not read u32 size from byte array".to_owned()) - })?); + fn read_size<'a>(buf: &'a [u8], vm: &VirtualMachine) -> PyResult<(usize, &'a [u8])> { + if buf.len() < 4 { + return Err(too_short_error(vm)); + } + let (u32_bytes, rest) = buf.split_at(4); + let length = u32::from_le_bytes(u32_bytes.try_into().unwrap()); Ok((length as usize, rest)) } - /// Reads next element from a python list. First by getting element size - /// then by building a pybuffer and "loading" the pyobject. - /// Returns rest of buffer with object. - fn next_element_of_list<'a>( - buf: &'a [u8], - vm: &VirtualMachine, - ) -> PyResult<(PyObjectRef, &'a [u8])> { - let (element_length, element_and_rest) = eat_length(buf, vm)?; - let (element_buff, rest) = element_and_rest.split_at(element_length); - let pybuffer = PyBuffer::from_byte_vector(element_buff.to_vec(), vm); - Ok((loads(pybuffer, vm)?, rest)) - } - /// Reads a list (or tuple) from a buffer. - fn read_list(buf: &[u8], vm: &VirtualMachine) -> PyResult> { - let (expected_array_len, mut buffer) = eat_length(buf, vm)?; + fn load_seq<'b>(buf: &'b [u8], vm: &VirtualMachine) -> PyResult<(Vec, &'b [u8])> { + let (len, mut buf) = read_size(buf, vm)?; let mut elements: Vec = Vec::new(); - while !buffer.is_empty() { - let (element, rest_of_buffer) = next_element_of_list(buffer, vm)?; + for _ in 0..len { + let (element, rest) = load_obj(buf, vm)?; + buf = rest; elements.push(element); - buffer = rest_of_buffer; } - debug_assert!(expected_array_len == elements.len()); - Ok(elements) - } - - /// Builds a PyDict from iterator of tuple objects - pub fn from_tuples(iterable: Iter, vm: &VirtualMachine) -> PyResult { - let dict = DictContentType::default(); - for elem in iterable { - let items = match_class!(match elem.clone() { - pytuple @ PyTuple => pytuple.to_vec(), - _ => - return Err(vm.new_value_error( - "Couldn't unmarshal key:value pair of dictionary".to_owned() - )), - }); - // Marshalled tuples are always in format key:value. - dict.insert(vm, &**items.get(0).unwrap(), items.get(1).unwrap().clone())?; - } - Ok(PyDict::from_entries(dict)) + Ok((elements, buf)) } #[pyfunction] fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult { - let full_buff = pybuffer.as_contiguous().ok_or_else(|| { + let buf = pybuffer.as_contiguous().ok_or_else(|| { vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned()) })?; - let (type_indicator, buf) = full_buff.split_last().ok_or_else(|| { - vm.new_exception_msg( - vm.ctx.exceptions.eof_error.to_owned(), - "EOF where object expected.".to_owned(), - ) - })?; - match *type_indicator { - BOOL_BYTE => Ok((buf[0] != 0).to_pyobject(vm)), + let (obj, _) = load_obj(&buf, vm)?; + Ok(obj) + } + + fn load_obj<'b>(buf: &'b [u8], vm: &VirtualMachine) -> PyResult<(PyObjectRef, &'b [u8])> { + let (type_indicator, buf) = buf.split_first().ok_or_else(|| too_short_error(vm))?; + let (obj, buf) = match *type_indicator { + TRUE_BYTE => ((true).to_pyobject(vm), buf), + FALSE_BYTE => ((false).to_pyobject(vm), buf), INT_BYTE => { - let (sign_byte, uint_bytes) = buf - .split_first() - .ok_or_else(|| vm.new_value_error("EOF where object expected.".to_owned()))?; - let sign = match sign_byte { - b'-' => Sign::Minus, - b'0' => Sign::NoSign, - b'+' => Sign::Plus, - _ => { - return Err(vm.new_value_error( - "Unknown sign byte when trying to unmarshal integer".to_owned(), - )) - } + if buf.len() < 4 { + return Err(too_short_error(vm)); + } + let (len_bytes, buf) = buf.split_at(4); + let len = i32::from_le_bytes(len_bytes.try_into().unwrap()); + let (sign, len) = if len < 0 { + (Sign::Minus, (-len) as usize) + } else { + (Sign::Plus, len as usize) }; - let pyint = BigInt::from_bytes_le(sign, uint_bytes); - Ok(pyint.to_pyobject(vm)) + if buf.len() < len { + return Err(too_short_error(vm)); + } + let (bytes, buf) = buf.split_at(len); + let int = BigInt::from_bytes_le(sign, bytes); + (int.to_pyobject(vm), buf) } FLOAT_BYTE => { - let number = f64::from_le_bytes(match buf[..].try_into() { - Ok(byte_array) => byte_array, - Err(e) => { - return Err(vm.new_value_error(format!( - "Expected float, could not load from bytes. {}", - e - ))) - } - }); - let pyfloat = PyFloat::from(number); - Ok(pyfloat.to_pyobject(vm)) + if buf.len() < 8 { + return Err(too_short_error(vm)); + } + let (bytes, buf) = buf.split_at(8); + let number = f64::from_le_bytes(bytes.try_into().unwrap()); + (vm.ctx.new_float(number).into(), buf) } STR_BYTE => { - let pystr = PyStr::from(match AsciiStr::from_ascii(buf) { - Ok(ascii_str) => ascii_str, - Err(e) => { - return Err( - vm.new_value_error(format!("Cannot unmarshal bytes to string, {}", e)) - ) - } - }); - Ok(pystr.to_pyobject(vm)) + let (len, buf) = read_size(buf, vm)?; + if buf.len() < len { + return Err(too_short_error(vm)); + } + let (bytes, buf) = buf.split_at(len); + let s = String::from_utf8(bytes.to_vec()) + .map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?; + (s.to_pyobject(vm), buf) } LIST_BYTE => { - let elements = read_list(buf, vm)?; - Ok(elements.to_pyobject(vm)) + let (elements, buf) = load_seq(buf, vm)?; + (vm.ctx.new_list(elements).into(), buf) } SET_BYTE => { - let elements = read_list(buf, vm)?; + let (elements, buf) = load_seq(buf, vm)?; let set = PySet::new_ref(&vm.ctx); for element in elements { set.add(element, vm)?; } - Ok(set.to_pyobject(vm)) + (set.to_pyobject(vm), buf) } FROZEN_SET_BYTE => { - let elements = read_list(buf, vm)?; + let (elements, buf) = load_seq(buf, vm)?; let set = PyFrozenSet::from_iter(vm, elements.into_iter())?; - Ok(set.to_pyobject(vm)) + (set.to_pyobject(vm), buf) } TUPLE_BYTE => { - let elements = read_list(buf, vm)?; - let pytuple = PyTuple::new_ref(elements, &vm.ctx).to_pyobject(vm); - Ok(pytuple) + let (elements, buf) = load_seq(buf, vm)?; + (vm.ctx.new_tuple(elements).into(), buf) } DICT_BYTE => { - let pybuffer = PyBuffer::from_byte_vector(buf[..].to_vec(), vm); - let pydict = match_class!(match loads(pybuffer, vm)? { - pylist @ PyList => from_tuples(pylist.borrow_vec().iter(), vm)?, - _ => - return Err(vm.new_value_error("Couldn't unmarshal dicitionary.".to_owned())), - }); - Ok(pydict.to_pyobject(vm)) + let (len, mut buf) = read_size(buf, vm)?; + let dict = vm.ctx.new_dict(); + for _ in 0..len { + let (key, rest) = load_obj(buf, vm)?; + let (value, rest) = load_obj(rest, vm)?; + buf = rest; + dict.set_item(key.as_object(), value, vm)?; + } + (dict.into(), buf) } BYTE_ARRAY => { // Following CPython, after marshaling, byte arrays are converted into bytes. - let byte_array = PyBytes::from(buf[..].to_vec()); - Ok(byte_array.to_pyobject(vm)) + let (len, buf) = read_size(buf, vm)?; + if buf.len() < len { + return Err(too_short_error(vm)); + } + let (bytes, buf) = buf.split_at(len); + (vm.ctx.new_bytes(bytes.to_vec()).into(), buf) } - _ => { + TYPE_CODE => { // If prefix is not identifiable, assume CodeObject, error out if it doesn't match. - let code = bytecode::CodeObject::from_bytes(&full_buff).map_err(|e| match e { + let (len, buf) = read_size(buf, vm)?; + if buf.len() < len { + return Err(too_short_error(vm)); + } + let (bytes, buf) = buf.split_at(len); + let code = bytecode::CodeObject::from_bytes(bytes).map_err(|e| match e { bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( vm.ctx.exceptions.eof_error.to_owned(), "End of file while deserializing bytecode".to_owned(), ), _ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), })?; - Ok(vm.ctx.new_code(code).into()) + (vm.ctx.new_code(code).into(), buf) } - } + _ => return Err(vm.new_value_error("bad marshal data (unknown type code)".to_owned())), + }; + Ok((obj, buf)) } #[pyfunction] From aa58cd2e75c5f2ceba9f5da350f3562d19658118 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sun, 7 Aug 2022 03:10:24 +0900 Subject: [PATCH 3/3] Use enum marshal::Type instead of u8 --- vm/src/stdlib/marshal.rs | 133 +++++++++++++++++++++++++++------------ 1 file changed, 94 insertions(+), 39 deletions(-) diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index 9615eec95..ae1cfa27b 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -19,18 +19,72 @@ mod decl { use num_bigint::{BigInt, Sign}; use num_traits::Zero; - const STR_BYTE: u8 = b's'; - const INT_BYTE: u8 = b'i'; - const FLOAT_BYTE: u8 = b'f'; - const TRUE_BYTE: u8 = b'T'; - const FALSE_BYTE: u8 = b'F'; - const LIST_BYTE: u8 = b'['; - const TUPLE_BYTE: u8 = b'('; - const DICT_BYTE: u8 = b','; - const SET_BYTE: u8 = b'~'; - const FROZEN_SET_BYTE: u8 = b'<'; - const BYTE_ARRAY: u8 = b'>'; - const TYPE_CODE: u8 = b'c'; + #[repr(u8)] + enum Type { + // Null = b'0', + // None = b'N', + False = b'F', + True = b'T', + // StopIter = b'S', + // Ellipsis = b'.', + Int = b'i', + Float = b'g', + // Complex = b'y', + // Long = b'l', // i32 + Bytes = b's', // = TYPE_STRING + // Interned = b't', + // Ref = b'r', + Tuple = b'(', + List = b'[', + Dict = b'{', + Code = b'c', + Str = b'u', // = TYPE_UNICODE + // Unknown = b'?', + Set = b'<', + FrozenSet = b'>', + // Ascii = b'a', + // AsciiInterned = b'A', + // SmallTuple = b')', + // ShortAscii = b'z', + // ShortAsciiInterned = b'Z', + } + // const FLAG_REF: u8 = b'\x80'; + + impl TryFrom for Type { + type Error = u8; + fn try_from(value: u8) -> Result { + use Type::*; + Ok(match value { + // b'0' => Null, + // b'N' => None, + b'F' => False, + b'T' => True, + // b'S' => StopIter, + // b'.' => Ellipsis, + b'i' => Int, + b'g' => Float, + // b'y' => Complex, + // b'l' => Long, + b's' => Bytes, + // b't' => Interned, + // b'r' => Ref, + b'(' => Tuple, + b'[' => List, + b'{' => Dict, + b'c' => Code, + b'u' => Str, + // b'?' => Unknown, + b'<' => Set, + b'>' => FrozenSet, + // b'a' => Ascii, + // b'A' => AsciiInterned, + // b')' => SmallTuple, + // b'z' => ShortAscii, + // b'Z' => ShortAsciiInterned, + c => return Err(c), + }) + } + } fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef { vm.new_exception_msg( @@ -59,13 +113,13 @@ mod decl { pyint @ PyInt => { if pyint.class().is(vm.ctx.types.bool_type) { let typ = if pyint.as_bigint().is_zero() { - FALSE_BYTE + Type::False } else { - TRUE_BYTE + Type::True }; - buf.push(typ); + buf.push(typ as u8); } else { - buf.push(INT_BYTE); + buf.push(Type::Int as u8); let (sign, int_bytes) = pyint.as_bigint().to_bytes_le(); let mut len = int_bytes.len() as i32; if sign == Sign::Minus { @@ -76,35 +130,35 @@ mod decl { } } pyfloat @ PyFloat => { - buf.push(FLOAT_BYTE); + buf.push(Type::Float as u8); buf.extend(pyfloat.to_f64().to_le_bytes()); } pystr @ PyStr => { - buf.push(STR_BYTE); + buf.push(Type::Str as u8); write_size(buf, pystr.as_str().len(), vm)?; buf.extend(pystr.as_str().as_bytes()); } pylist @ PyList => { - buf.push(LIST_BYTE); + buf.push(Type::List as u8); let pylist_items = pylist.borrow_vec(); dump_seq(buf, pylist_items.iter(), vm)?; } pyset @ PySet => { - buf.push(SET_BYTE); + buf.push(Type::Set as u8); let elements = pyset.elements(); dump_seq(buf, elements.iter(), vm)?; } pyfrozen @ PyFrozenSet => { - buf.push(FROZEN_SET_BYTE); + buf.push(Type::FrozenSet as u8); let elements = pyfrozen.elements(); dump_seq(buf, elements.iter(), vm)?; } pytuple @ PyTuple => { - buf.push(TUPLE_BYTE); + buf.push(Type::Tuple as u8); dump_seq(buf, pytuple.iter(), vm)?; } pydict @ PyDict => { - buf.push(DICT_BYTE); + buf.push(Type::Dict as u8); write_size(buf, pydict.len(), vm)?; for (key, value) in pydict { dump_obj(buf, key, vm)?; @@ -112,13 +166,13 @@ mod decl { } } bytes @ PyByteArray => { - buf.push(BYTE_ARRAY); + buf.push(Type::Bytes as u8); let data = bytes.borrow_buf(); write_size(buf, data.len(), vm)?; buf.extend(&*data); } co @ PyCode => { - buf.push(TYPE_CODE); + buf.push(Type::Code as u8); let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); write_size(buf, bytes.len(), vm)?; buf.extend(bytes); @@ -191,10 +245,12 @@ mod decl { fn load_obj<'b>(buf: &'b [u8], vm: &VirtualMachine) -> PyResult<(PyObjectRef, &'b [u8])> { let (type_indicator, buf) = buf.split_first().ok_or_else(|| too_short_error(vm))?; - let (obj, buf) = match *type_indicator { - TRUE_BYTE => ((true).to_pyobject(vm), buf), - FALSE_BYTE => ((false).to_pyobject(vm), buf), - INT_BYTE => { + let typ = Type::try_from(*type_indicator) + .map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?; + let (obj, buf) = match typ { + Type::True => ((true).to_pyobject(vm), buf), + Type::False => ((false).to_pyobject(vm), buf), + Type::Int => { if buf.len() < 4 { return Err(too_short_error(vm)); } @@ -212,7 +268,7 @@ mod decl { let int = BigInt::from_bytes_le(sign, bytes); (int.to_pyobject(vm), buf) } - FLOAT_BYTE => { + Type::Float => { if buf.len() < 8 { return Err(too_short_error(vm)); } @@ -220,7 +276,7 @@ mod decl { let number = f64::from_le_bytes(bytes.try_into().unwrap()); (vm.ctx.new_float(number).into(), buf) } - STR_BYTE => { + Type::Str => { let (len, buf) = read_size(buf, vm)?; if buf.len() < len { return Err(too_short_error(vm)); @@ -230,11 +286,11 @@ mod decl { .map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?; (s.to_pyobject(vm), buf) } - LIST_BYTE => { + Type::List => { let (elements, buf) = load_seq(buf, vm)?; (vm.ctx.new_list(elements).into(), buf) } - SET_BYTE => { + Type::Set => { let (elements, buf) = load_seq(buf, vm)?; let set = PySet::new_ref(&vm.ctx); for element in elements { @@ -242,16 +298,16 @@ mod decl { } (set.to_pyobject(vm), buf) } - FROZEN_SET_BYTE => { + Type::FrozenSet => { let (elements, buf) = load_seq(buf, vm)?; let set = PyFrozenSet::from_iter(vm, elements.into_iter())?; (set.to_pyobject(vm), buf) } - TUPLE_BYTE => { + Type::Tuple => { let (elements, buf) = load_seq(buf, vm)?; (vm.ctx.new_tuple(elements).into(), buf) } - DICT_BYTE => { + Type::Dict => { let (len, mut buf) = read_size(buf, vm)?; let dict = vm.ctx.new_dict(); for _ in 0..len { @@ -262,7 +318,7 @@ mod decl { } (dict.into(), buf) } - BYTE_ARRAY => { + Type::Bytes => { // Following CPython, after marshaling, byte arrays are converted into bytes. let (len, buf) = read_size(buf, vm)?; if buf.len() < len { @@ -271,7 +327,7 @@ mod decl { let (bytes, buf) = buf.split_at(len); (vm.ctx.new_bytes(bytes.to_vec()).into(), buf) } - TYPE_CODE => { + Type::Code => { // If prefix is not identifiable, assume CodeObject, error out if it doesn't match. let (len, buf) = read_size(buf, vm)?; if buf.len() < len { @@ -287,7 +343,6 @@ mod decl { })?; (vm.ctx.new_code(code).into(), buf) } - _ => return Err(vm.new_value_error("bad marshal data (unknown type code)".to_owned())), }; Ok((obj, buf)) }