diff --git a/tests/snippets/json_snippet.py b/tests/snippets/json_snippet.py index d47266b58..31e6648f6 100644 --- a/tests/snippets/json_snippet.py +++ b/tests/snippets/json_snippet.py @@ -1,6 +1,6 @@ from testutils import assert_raises import json -from io import StringIO +from io import StringIO, BytesIO def round_trip_test(obj): # serde_json and Python's json module produce slightly differently spaced @@ -15,7 +15,7 @@ def json_dump(obj): return f.getvalue() def json_load(obj): - f = StringIO(obj) + f = StringIO(obj) if isinstance(obj, str) else BytesIO(bytes(obj)) return json.load(f) assert '"string"' == json.dumps("string") @@ -70,42 +70,92 @@ assert_raises(json.JSONDecodeError, lambda: json_load('{3: "abc"}')) assert json.dumps({'3': 'abc'}) == json.dumps({3: 'abc'}) assert 1 == json.loads("1") +assert 1 == json.loads(b"1") +assert 1 == json.loads(bytearray(b"1")) assert 1 == json_load("1") +assert 1 == json_load(b"1") +assert 1 == json_load(bytearray(b"1")) assert -1 == json.loads("-1") +assert -1 == json.loads(b"-1") +assert -1 == json.loads(bytearray(b"-1")) assert -1 == json_load("-1") +assert -1 == json_load(b"-1") +assert -1 == json_load(bytearray(b"-1")) assert 1.0 == json.loads("1.0") +assert 1.0 == json.loads(b"1.0") +assert 1.0 == json.loads(bytearray(b"1.0")) assert 1.0 == json_load("1.0") +assert 1.0 == json_load(b"1.0") +assert 1.0 == json_load(bytearray(b"1.0")) assert -1.0 == json.loads("-1.0") +assert -1.0 == json.loads(b"-1.0") +assert -1.0 == json.loads(bytearray(b"-1.0")) assert -1.0 == json_load("-1.0") +assert -1.0 == json_load(b"-1.0") +assert -1.0 == json_load(bytearray(b"-1.0")) assert "str" == json.loads('"str"') +assert "str" == json.loads(b'"str"') +assert "str" == json.loads(bytearray(b'"str"')) assert "str" == json_load('"str"') +assert "str" == json_load(b'"str"') +assert "str" == json_load(bytearray(b'"str"')) assert True is json.loads('true') +assert True is json.loads(b'true') +assert True is json.loads(bytearray(b'true')) assert True is json_load('true') +assert True is json_load(b'true') +assert True is json_load(bytearray(b'true')) assert False is json.loads('false') +assert False is json.loads(b'false') +assert False is json.loads(bytearray(b'false')) assert False is json_load('false') +assert False is json_load(b'false') +assert False is json_load(bytearray(b'false')) assert None is json.loads('null') +assert None is json.loads(b'null') +assert None is json.loads(bytearray(b'null')) assert None is json_load('null') +assert None is json_load(b'null') +assert None is json_load(bytearray(b'null')) assert [] == json.loads('[]') +assert [] == json.loads(b'[]') +assert [] == json.loads(bytearray(b'[]')) assert [] == json_load('[]') +assert [] == json_load(b'[]') +assert [] == json_load(bytearray(b'[]')) assert ['a'] == json.loads('["a"]') +assert ['a'] == json.loads(b'["a"]') +assert ['a'] == json.loads(bytearray(b'["a"]')) assert ['a'] == json_load('["a"]') +assert ['a'] == json_load(b'["a"]') +assert ['a'] == json_load(bytearray(b'["a"]')) assert [['a'], 'b'] == json.loads('[["a"], "b"]') +assert [['a'], 'b'] == json.loads(b'[["a"], "b"]') +assert [['a'], 'b'] == json.loads(bytearray(b'[["a"], "b"]')) assert [['a'], 'b'] == json_load('[["a"], "b"]') +assert [['a'], 'b'] == json_load(b'[["a"], "b"]') +assert [['a'], 'b'] == json_load(bytearray(b'[["a"], "b"]')) class String(str): pass +class Bytes(bytes): pass +class ByteArray(bytearray): pass assert "string" == json.loads(String('"string"')) +assert "string" == json.loads(Bytes(b'"string"')) +assert "string" == json.loads(ByteArray(b'"string"')) assert "string" == json_load(String('"string"')) +assert "string" == json_load(Bytes(b'"string"')) +assert "string" == json_load(ByteArray(b'"string"')) assert '"string"' == json.dumps(String("string")) assert '"string"' == json_dump(String("string")) diff --git a/vm/src/stdlib/json.rs b/vm/src/stdlib/json.rs index aa3fb8157..f7102f9e2 100644 --- a/vm/src/stdlib/json.rs +++ b/vm/src/stdlib/json.rs @@ -1,6 +1,8 @@ -use crate::obj::objstr::PyStringRef; +use crate::obj::objbytearray::PyByteArray; +use crate::obj::objbytes::PyBytes; +use crate::obj::objstr::PyString; use crate::py_serde; -use crate::pyobject::{ItemProtocol, PyObjectRef, PyResult}; +use crate::pyobject::{ItemProtocol, PyObjectRef, PyResult, TypeProtocol}; use crate::types::create_type; use crate::VirtualMachine; use serde_json; @@ -18,11 +20,19 @@ pub fn json_dump(obj: PyObjectRef, fs: PyObjectRef, vm: &VirtualMachine) -> PyRe } /// Implement json.loads -pub fn json_loads(string: PyStringRef, vm: &VirtualMachine) -> PyResult { - // TODO: Implement non-trivial deserialization case - let de_result = - py_serde::deserialize(vm, &mut serde_json::Deserializer::from_str(string.as_str())); - +pub fn json_loads(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult { + let de_result = match_class!(obj, + s @ PyString => py_serde::deserialize(vm, &mut serde_json::Deserializer::from_str(s.as_str())), + b @ PyBytes => py_serde::deserialize(vm, &mut serde_json::Deserializer::from_slice(&b)), + ba @ PyByteArray => py_serde::deserialize(vm, &mut serde_json::Deserializer::from_slice(&ba.inner.borrow().elements)), + obj => { + let msg = format!( + "the JSON object must be str, bytes or bytearray, not {}", + obj.class().name + ); + return Err(vm.new_type_error(msg)); + } + ); de_result.map_err(|err| { let module = vm .get_attribute(vm.sys_module.clone(), "modules") @@ -42,7 +52,7 @@ pub fn json_loads(string: PyStringRef, vm: &VirtualMachine) -> PyResult { pub fn json_load(fp: PyObjectRef, vm: &VirtualMachine) -> PyResult { let result = vm.call_method(&fp, "read", vec![])?; - json_loads(result.downcast()?, vm) + json_loads(result, vm) } pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {