mirror of
https://github.com/RustPython/RustPython.git
synced 2026-06-02 19:39:49 +09:00
Fix stack overflow on deeply-nested JSON in json.loads() (#7632)
* Fix stack overflow on deeply-nested JSON in json.loads()
json.loads() on a deeply-nested array or object payload (e.g.
'[' * 50000 + ']' * 50000) overflowed the native Rust stack and
crashed the interpreter process with SIGSEGV. CPython raises
RecursionError on the same input via _Py_EnterRecursiveCall in
Modules/_json.c.
The recursion lives in the mutual call chain:
JsonScanner::parse_object / parse_array
-> JsonScanner::call_scan_once
-> JsonScanner::parse_object / parse_array
Every descent funnels through call_scan_once, so wrapping its body
with vm.with_recursion covers both '{' and '[' paths (and their
mixed nesting) with a single guard.
Before:
./rustpython -c "import json; json.loads('[' * 50000 + ']' * 50000)"
-> SIGSEGV (exit 139)
After:
-> RecursionError: maximum recursion depth exceeded while
decoding a JSON object from a string
Verified:
- extra_tests/snippets/stdlib_json.py: all assertions pass
(includes 3 new regression cases: array, object, alternating
nesting at depth 100000)
- cargo run -- -m test test_json: 214 passed, 0 regressed
(9 skipped, 13 expected failures, all pre-existing)
- depth 500000 no longer crashes (RecursionError)
- shallow parsing unchanged
* Enable test_highly_nested_objects_decoding
Per @ShaharNaveh's review on #7632: this test was previously marked
`@unittest.skip("TODO: RUSTPYTHON; crashes")` because json.loads
would SIGSEGV on the 500_000-deep input. The recursion-guard added
in this PR makes it raise RecursionError like CPython, so the skip
decorator can be removed.
$ cargo run -- -m unittest \
test.test_json.test_recursion.TestCRecursion.test_highly_nested_objects_decoding \
test.test_json.test_recursion.TestPyRecursion.test_highly_nested_objects_decoding
...
Ran 2 tests in 0.825s
OK
$ cargo run -- -m test test_json
Ran 214 tests (7 skipped, 13 expected failures) — all pass.
This commit is contained in:
1
Lib/test/test_json/test_recursion.py
vendored
1
Lib/test/test_json/test_recursion.py
vendored
@@ -70,7 +70,6 @@ class TestRecursion:
|
||||
self.fail("didn't raise ValueError on default recursion")
|
||||
|
||||
|
||||
@unittest.skip("TODO: RUSTPYTHON; crashes")
|
||||
@support.skip_if_unlimited_stack_size
|
||||
@support.skip_emscripten_stack_overflow()
|
||||
@support.skip_wasi_stack_overflow()
|
||||
|
||||
@@ -513,107 +513,116 @@ mod _json {
|
||||
memo: &mut HashMap<String, PyStrRef>,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<(PyObjectRef, usize, usize)> {
|
||||
let bytes = pystr.as_bytes();
|
||||
let wtf8 = pystr.as_wtf8();
|
||||
let s = pystr.as_str();
|
||||
// Recursion guard: parse_object/parse_array recurse into call_scan_once
|
||||
// for each child value. Without this, a deeply-nested input like
|
||||
// `'[' * 50000 + ']' * 50000` overflows the native Rust stack and
|
||||
// crashes the process with SIGSEGV. Matches CPython's
|
||||
// _Py_EnterRecursiveCall in Modules/_json.c.
|
||||
vm.with_recursion("while decoding a JSON object from a string", || {
|
||||
let bytes = pystr.as_bytes();
|
||||
let wtf8 = pystr.as_wtf8();
|
||||
let s = pystr.as_str();
|
||||
|
||||
let first_byte = match bytes.get(byte_idx) {
|
||||
Some(&b) => b,
|
||||
None => return Err(self.make_decode_error("Expecting value", pystr, char_idx, vm)),
|
||||
};
|
||||
let first_byte = match bytes.get(byte_idx) {
|
||||
Some(&b) => b,
|
||||
None => {
|
||||
return Err(self.make_decode_error("Expecting value", pystr, char_idx, vm));
|
||||
}
|
||||
};
|
||||
|
||||
match first_byte {
|
||||
b'"' => {
|
||||
// String - pass slice starting after the quote
|
||||
let (wtf8_result, chars_consumed, bytes_consumed) =
|
||||
machinery::scanstring(&wtf8[byte_idx + 1..], char_idx + 1, self.strict)
|
||||
.map_err(|e| py_decode_error(e, pystr.clone().into_wtf8(), vm))?;
|
||||
let py_str = vm.ctx.new_str(wtf8_result.to_string());
|
||||
Ok((
|
||||
py_str.into(),
|
||||
char_idx + 1 + chars_consumed,
|
||||
byte_idx + 1 + bytes_consumed,
|
||||
))
|
||||
}
|
||||
b'{' => {
|
||||
// Object
|
||||
self.parse_object(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
|
||||
}
|
||||
b'[' => {
|
||||
// Array
|
||||
self.parse_array(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
|
||||
}
|
||||
b'n' if starts_with_bytes(&bytes[byte_idx..], b"null") => {
|
||||
// null
|
||||
Ok((vm.ctx.none(), char_idx + 4, byte_idx + 4))
|
||||
}
|
||||
b't' if starts_with_bytes(&bytes[byte_idx..], b"true") => {
|
||||
// true
|
||||
Ok((vm.ctx.new_bool(true).into(), char_idx + 4, byte_idx + 4))
|
||||
}
|
||||
b'f' if starts_with_bytes(&bytes[byte_idx..], b"false") => {
|
||||
// false
|
||||
Ok((vm.ctx.new_bool(false).into(), char_idx + 5, byte_idx + 5))
|
||||
}
|
||||
b'N' if starts_with_bytes(&bytes[byte_idx..], b"NaN") => {
|
||||
// NaN
|
||||
let result = self.parse_constant.call(("NaN",), vm)?;
|
||||
Ok((result, char_idx + 3, byte_idx + 3))
|
||||
}
|
||||
b'I' if starts_with_bytes(&bytes[byte_idx..], b"Infinity") => {
|
||||
// Infinity
|
||||
let result = self.parse_constant.call(("Infinity",), vm)?;
|
||||
Ok((result, char_idx + 8, byte_idx + 8))
|
||||
}
|
||||
b'-' => {
|
||||
// -Infinity or negative number
|
||||
if starts_with_bytes(&bytes[byte_idx..], b"-Infinity") {
|
||||
let result = self.parse_constant.call(("-Infinity",), vm)?;
|
||||
return Ok((result, char_idx + 9, byte_idx + 9));
|
||||
match first_byte {
|
||||
b'"' => {
|
||||
// String - pass slice starting after the quote
|
||||
let (wtf8_result, chars_consumed, bytes_consumed) =
|
||||
machinery::scanstring(&wtf8[byte_idx + 1..], char_idx + 1, self.strict)
|
||||
.map_err(|e| py_decode_error(e, pystr.clone().into_wtf8(), vm))?;
|
||||
let py_str = vm.ctx.new_str(wtf8_result.to_string());
|
||||
Ok((
|
||||
py_str.into(),
|
||||
char_idx + 1 + chars_consumed,
|
||||
byte_idx + 1 + bytes_consumed,
|
||||
))
|
||||
}
|
||||
// Negative number - numbers are ASCII so len == bytes
|
||||
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
|
||||
return Ok((result?, char_idx + len, byte_idx + len));
|
||||
b'{' => {
|
||||
// Object
|
||||
self.parse_object(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
|
||||
}
|
||||
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
|
||||
}
|
||||
b'0'..=b'9' => {
|
||||
// Positive number - numbers are ASCII so len == bytes
|
||||
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
|
||||
return Ok((result?, char_idx + len, byte_idx + len));
|
||||
b'[' => {
|
||||
// Array
|
||||
self.parse_array(pystr, char_idx + 1, byte_idx + 1, scan_once, memo, vm)
|
||||
}
|
||||
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
|
||||
}
|
||||
_ => {
|
||||
// Fall back to scan_once for unrecognized input
|
||||
// Note: This path requires char_idx for Python compatibility
|
||||
let result = scan_once.call((pystr.clone(), char_idx as isize), vm);
|
||||
b'n' if starts_with_bytes(&bytes[byte_idx..], b"null") => {
|
||||
// null
|
||||
Ok((vm.ctx.none(), char_idx + 4, byte_idx + 4))
|
||||
}
|
||||
b't' if starts_with_bytes(&bytes[byte_idx..], b"true") => {
|
||||
// true
|
||||
Ok((vm.ctx.new_bool(true).into(), char_idx + 4, byte_idx + 4))
|
||||
}
|
||||
b'f' if starts_with_bytes(&bytes[byte_idx..], b"false") => {
|
||||
// false
|
||||
Ok((vm.ctx.new_bool(false).into(), char_idx + 5, byte_idx + 5))
|
||||
}
|
||||
b'N' if starts_with_bytes(&bytes[byte_idx..], b"NaN") => {
|
||||
// NaN
|
||||
let result = self.parse_constant.call(("NaN",), vm)?;
|
||||
Ok((result, char_idx + 3, byte_idx + 3))
|
||||
}
|
||||
b'I' if starts_with_bytes(&bytes[byte_idx..], b"Infinity") => {
|
||||
// Infinity
|
||||
let result = self.parse_constant.call(("Infinity",), vm)?;
|
||||
Ok((result, char_idx + 8, byte_idx + 8))
|
||||
}
|
||||
b'-' => {
|
||||
// -Infinity or negative number
|
||||
if starts_with_bytes(&bytes[byte_idx..], b"-Infinity") {
|
||||
let result = self.parse_constant.call(("-Infinity",), vm)?;
|
||||
return Ok((result, char_idx + 9, byte_idx + 9));
|
||||
}
|
||||
// Negative number - numbers are ASCII so len == bytes
|
||||
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
|
||||
return Ok((result?, char_idx + len, byte_idx + len));
|
||||
}
|
||||
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
|
||||
}
|
||||
b'0'..=b'9' => {
|
||||
// Positive number - numbers are ASCII so len == bytes
|
||||
if let Some((result, len)) = self.parse_number(&s[byte_idx..], vm) {
|
||||
return Ok((result?, char_idx + len, byte_idx + len));
|
||||
}
|
||||
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
|
||||
}
|
||||
_ => {
|
||||
// Fall back to scan_once for unrecognized input
|
||||
// Note: This path requires char_idx for Python compatibility
|
||||
let result = scan_once.call((pystr.clone(), char_idx as isize), vm);
|
||||
|
||||
match result {
|
||||
Ok(tuple) => {
|
||||
use crate::vm::builtins::PyTupleRef;
|
||||
let tuple: PyTupleRef = tuple.try_into_value(vm)?;
|
||||
if tuple.len() != 2 {
|
||||
return Err(vm.new_value_error("scan_once must return 2-tuple"));
|
||||
match result {
|
||||
Ok(tuple) => {
|
||||
use crate::vm::builtins::PyTupleRef;
|
||||
let tuple: PyTupleRef = tuple.try_into_value(vm)?;
|
||||
if tuple.len() != 2 {
|
||||
return Err(vm.new_value_error("scan_once must return 2-tuple"));
|
||||
}
|
||||
let value = tuple.as_slice()[0].clone();
|
||||
let end_char_idx: isize = tuple.as_slice()[1].try_to_value(vm)?;
|
||||
// For fallback, we need to calculate byte_idx from char_idx
|
||||
// This is expensive but fallback should be rare
|
||||
let end_byte_idx = s
|
||||
.char_indices()
|
||||
.nth(end_char_idx as usize)
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(s.len());
|
||||
Ok((value, end_char_idx as usize, end_byte_idx))
|
||||
}
|
||||
let value = tuple.as_slice()[0].clone();
|
||||
let end_char_idx: isize = tuple.as_slice()[1].try_to_value(vm)?;
|
||||
// For fallback, we need to calculate byte_idx from char_idx
|
||||
// This is expensive but fallback should be rare
|
||||
let end_byte_idx = s
|
||||
.char_indices()
|
||||
.nth(end_char_idx as usize)
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(s.len());
|
||||
Ok((value, end_char_idx as usize, end_byte_idx))
|
||||
Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
|
||||
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
Err(err) if err.fast_isinstance(vm.ctx.exceptions.stop_iteration) => {
|
||||
Err(self.make_decode_error("Expecting value", pystr, char_idx, vm))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a decode error.
|
||||
|
||||
@@ -217,3 +217,25 @@ i = 7**500
|
||||
assert json.dumps(i) == str(i)
|
||||
|
||||
assert json.decoder.scanstring('✨x"', 1) == ("x", 3)
|
||||
|
||||
|
||||
# Recursion guard: deeply-nested input must raise RecursionError instead of
|
||||
# overflowing the native stack (SIGSEGV). Matches CPython's
|
||||
# _Py_EnterRecursiveCall in Modules/_json.c.
|
||||
|
||||
_deep = 100_000 # well above the ~45k native-stack crash threshold
|
||||
|
||||
# Array nesting
|
||||
assert_raises(RecursionError, lambda: json.loads("[" * _deep + "]" * _deep))
|
||||
|
||||
# Object nesting
|
||||
assert_raises(
|
||||
RecursionError,
|
||||
lambda: json.loads('{"a":' * _deep + "1" + "}" * _deep),
|
||||
)
|
||||
|
||||
# Alternating array/object nesting
|
||||
assert_raises(
|
||||
RecursionError,
|
||||
lambda: json.loads(('[{"x":' * _deep) + "1" + ("}]" * _deep)),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user