Merge pull request #4006 from youknowone/marshal

a few more marshal
This commit is contained in:
Jeong YunWon
2022-08-07 05:20:24 +09:00
committed by GitHub
2 changed files with 126 additions and 108 deletions

View File

@@ -64,8 +64,6 @@ class IntTestCase(unittest.TestCase, HelperMixin):
self.helper(b)
class FloatTestCase(unittest.TestCase, HelperMixin):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_floats(self):
# Test a few floats
small = 1e-25
@@ -101,8 +99,6 @@ class StringTestCase(unittest.TestCase, HelperMixin):
for s in ["", "Andr\xe8 Previn", "abc", " "*10000]:
self.helper(s)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_bytes(self):
for s in [b"", b"Andr\xe8 Previn", b"abc", b" "*10000]:
self.helper(s)
@@ -202,14 +198,11 @@ class BugsTestCase(unittest.TestCase):
self.assertRaises(Exception, marshal.loads, b'f')
self.assertRaises(Exception, marshal.loads, marshal.dumps(2**65)[:-1])
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_version_argument(self):
# Python 2.4.0 crashes for any call to marshal.dumps(x, y)
self.assertEqual(marshal.loads(marshal.dumps(5, 0)), 5)
self.assertEqual(marshal.loads(marshal.dumps(5, 1)), 5)
@unittest.skip("TODO: RUSTPYTHON; panic")
def test_fuzz(self):
# simple test that it's at least not *totally* trivial to
# crash from bad marshal data
@@ -337,8 +330,6 @@ class BugsTestCase(unittest.TestCase):
self.assertRaises(ValueError, marshal.load,
BadReader(marshal.dumps(value)))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_eof(self):
data = marshal.dumps(("hello", "dolly", None))
for i in range(len(data)):
@@ -509,8 +500,7 @@ class InstancingTestCase(unittest.TestCase, HelperMixin):
self.helper(code)
self.helper3(code)
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.skip("TODO: RUSTPYTHON")
def testRecursion(self):
obj = 1.2345
d = {"hello": obj, "goodbye": obj, obj: "hello"}
@@ -529,23 +519,15 @@ class CompatibilityTestCase(unittest.TestCase):
data = marshal.dumps(code, version)
marshal.loads(data)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test0To3(self):
self._test(0)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test1To3(self):
self._test(1)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test2To3(self):
self._test(2)
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test3To3(self):
self._test(3)
@@ -562,8 +544,6 @@ class InterningTestCase(unittest.TestCase, HelperMixin):
s2 = sys.intern(s)
self.assertEqual(id(s2), id(s))
# TODO: RUSTPYTHON
@unittest.expectedFailure
def testNoIntern(self):
s = marshal.loads(marshal.dumps(self.strobj, 2))
self.assertEqual(s, self.strobj)

View File

@@ -9,24 +9,22 @@ mod decl {
},
bytecode,
convert::ToPyObject,
function::ArgBytesLike,
function::{ArgBytesLike, OptionalArg},
object::AsObject,
protocol::PyBuffer,
PyObjectRef, PyResult, TryFromObject, VirtualMachine,
};
/// TODO
/// PyBytes: Currently getting recursion error with match_class!
use num_bigint::{BigInt, Sign};
use num_traits::Zero;
#[repr(u8)]
enum Type {
// Null = b'0',
// None = b'N',
None = b'N',
False = b'F',
True = b'T',
// StopIter = b'S',
// Ellipsis = b'.',
Ellipsis = b'.',
Int = b'i',
Float = b'g',
// Complex = b'y',
@@ -38,11 +36,11 @@ mod decl {
List = b'[',
Dict = b'{',
Code = b'c',
Str = b'u', // = TYPE_UNICODE
Unicode = b'u',
// Unknown = b'?',
Set = b'<',
FrozenSet = b'>',
// Ascii = b'a',
Ascii = b'a',
// AsciiInterned = b'A',
// SmallTuple = b')',
// ShortAscii = b'z',
@@ -56,11 +54,11 @@ mod decl {
use Type::*;
Ok(match value {
// b'0' => Null,
// b'N' => None,
b'N' => None,
b'F' => False,
b'T' => True,
// b'S' => StopIter,
// b'.' => Ellipsis,
b'.' => Ellipsis,
b'i' => Int,
b'g' => Float,
// b'y' => Complex,
@@ -72,11 +70,11 @@ mod decl {
b'[' => List,
b'{' => Dict,
b'c' => Code,
b'u' => Str,
b'u' => Unicode,
// b'?' => Unknown,
b'<' => Set,
b'>' => FrozenSet,
// b'a' => Ascii,
b'a' => Ascii,
// b'A' => AsciiInterned,
// b')' => SmallTuple,
// b'z' => ShortAscii,
@@ -86,6 +84,9 @@ mod decl {
}
}
#[pyattr(name = "version")]
const VERSION: u32 = 4;
fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(
vm.ctx.exceptions.eof_error.to_owned(),
@@ -109,93 +110,118 @@ mod decl {
/// Dumping helper function to turn a value into bytes.
fn dump_obj(buf: &mut Vec<u8>, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
match_class!(match value {
pyint @ PyInt => {
if pyint.class().is(vm.ctx.types.bool_type) {
let typ = if pyint.as_bigint().is_zero() {
Type::False
if vm.is_none(&value) {
buf.push(Type::None as u8);
} else if value.is(&vm.ctx.ellipsis) {
buf.push(Type::Ellipsis as u8);
} else {
match_class!(match value {
pyint @ PyInt => {
if pyint.class().is(vm.ctx.types.bool_type) {
let typ = if pyint.as_bigint().is_zero() {
Type::False
} else {
Type::True
};
buf.push(typ as u8);
} else {
Type::True
};
buf.push(typ as u8);
} else {
buf.push(Type::Int as u8);
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
let mut len = int_bytes.len() as i32;
if sign == Sign::Minus {
len = -len;
buf.push(Type::Int as u8);
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
let mut len = int_bytes.len() as i32;
if sign == Sign::Minus {
len = -len;
}
buf.extend(len.to_le_bytes());
buf.extend(int_bytes);
}
buf.extend(len.to_le_bytes());
buf.extend(int_bytes);
}
}
pyfloat @ PyFloat => {
buf.push(Type::Float as u8);
buf.extend(pyfloat.to_f64().to_le_bytes());
}
pystr @ PyStr => {
buf.push(Type::Str as u8);
write_size(buf, pystr.as_str().len(), vm)?;
buf.extend(pystr.as_str().as_bytes());
}
pylist @ PyList => {
buf.push(Type::List as u8);
let pylist_items = pylist.borrow_vec();
dump_seq(buf, pylist_items.iter(), vm)?;
}
pyset @ PySet => {
buf.push(Type::Set as u8);
let elements = pyset.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pyfrozen @ PyFrozenSet => {
buf.push(Type::FrozenSet as u8);
let elements = pyfrozen.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pytuple @ PyTuple => {
buf.push(Type::Tuple as u8);
dump_seq(buf, pytuple.iter(), vm)?;
}
pydict @ PyDict => {
buf.push(Type::Dict as u8);
write_size(buf, pydict.len(), vm)?;
for (key, value) in pydict {
dump_obj(buf, key, vm)?;
dump_obj(buf, value, vm)?;
pyfloat @ PyFloat => {
buf.push(Type::Float as u8);
buf.extend(pyfloat.to_f64().to_le_bytes());
}
}
bytes @ PyByteArray => {
buf.push(Type::Bytes as u8);
let data = bytes.borrow_buf();
write_size(buf, data.len(), vm)?;
buf.extend(&*data);
}
co @ PyCode => {
buf.push(Type::Code as u8);
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
write_size(buf, bytes.len(), vm)?;
buf.extend(bytes);
}
_ => {
return Err(vm.new_not_implemented_error(
"TODO: not implemented yet or marshal unsupported type".to_owned(),
));
}
});
pystr @ PyStr => {
buf.push(if pystr.is_ascii() {
Type::Ascii
} else {
Type::Unicode
} as u8);
write_size(buf, pystr.as_str().len(), vm)?;
buf.extend(pystr.as_str().as_bytes());
}
pylist @ PyList => {
buf.push(Type::List as u8);
let pylist_items = pylist.borrow_vec();
dump_seq(buf, pylist_items.iter(), vm)?;
}
pyset @ PySet => {
buf.push(Type::Set as u8);
let elements = pyset.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pyfrozen @ PyFrozenSet => {
buf.push(Type::FrozenSet as u8);
let elements = pyfrozen.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pytuple @ PyTuple => {
buf.push(Type::Tuple as u8);
dump_seq(buf, pytuple.iter(), vm)?;
}
pydict @ PyDict => {
buf.push(Type::Dict as u8);
write_size(buf, pydict.len(), vm)?;
for (key, value) in pydict {
dump_obj(buf, key, vm)?;
dump_obj(buf, value, vm)?;
}
}
bytes @ PyBytes => {
buf.push(Type::Bytes as u8);
let data = bytes.as_bytes();
write_size(buf, data.len(), vm)?;
buf.extend(&*data);
}
bytes @ PyByteArray => {
buf.push(Type::Bytes as u8);
let data = bytes.borrow_buf();
write_size(buf, data.len(), vm)?;
buf.extend(&*data);
}
co @ PyCode => {
buf.push(Type::Code as u8);
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
write_size(buf, bytes.len(), vm)?;
buf.extend(bytes);
}
_ => {
return Err(vm.new_not_implemented_error(
"TODO: not implemented yet or marshal unsupported type".to_owned(),
));
}
})
}
Ok(())
}
#[pyfunction]
fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyBytes> {
fn dumps(
value: PyObjectRef,
_version: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<PyBytes> {
let mut buf = Vec::new();
dump_obj(&mut buf, value, vm)?;
Ok(PyBytes::from(buf))
}
#[pyfunction]
fn dump(value: PyObjectRef, f: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
let dumped = dumps(value, vm)?;
fn dump(
value: PyObjectRef,
f: PyObjectRef,
version: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<()> {
let dumped = dumps(value, version, vm)?;
vm.call_method(&f, "write", (dumped,))?;
Ok(())
}
@@ -248,8 +274,10 @@ mod decl {
let typ = Type::try_from(*type_indicator)
.map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?;
let (obj, buf) = match typ {
Type::True => ((true).to_pyobject(vm), buf),
Type::False => ((false).to_pyobject(vm), buf),
Type::True => (true.to_pyobject(vm), buf),
Type::False => (false.to_pyobject(vm), buf),
Type::None => (vm.ctx.none(), buf),
Type::Ellipsis => (vm.ctx.ellipsis(), buf),
Type::Int => {
if buf.len() < 4 {
return Err(too_short_error(vm));
@@ -276,7 +304,17 @@ mod decl {
let number = f64::from_le_bytes(bytes.try_into().unwrap());
(vm.ctx.new_float(number).into(), buf)
}
Type::Str => {
Type::Ascii => {
let (len, buf) = read_size(buf, vm)?;
if buf.len() < len {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(len);
let s = String::from_utf8(bytes.to_vec())
.map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?;
(s.to_pyobject(vm), buf)
}
Type::Unicode => {
let (len, buf) = read_size(buf, vm)?;
if buf.len() < len {
return Err(too_short_error(vm));