From c366c2074cb4bb83676553f46f2c2f3ce6d17ad4 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Sat, 1 Feb 2020 03:58:10 +0900 Subject: [PATCH] Fix type() error message and BaseException.__str__ --- Lib/test/test_types.py | 2 -- tests/snippets/exceptions.py | 2 ++ vm/src/exceptions.rs | 18 +++++++++++++++++- vm/src/obj/objtype.rs | 21 ++++++++++++++------- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 12d8b289d..9aa1ef9c5 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -1168,8 +1168,6 @@ class ClassCreationTests(unittest.TestCase): with self.assertRaises(TypeError): X = types.new_class("X", (int(), C)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_one_argument_type(self): expected_message = 'type.__new__() takes exactly 3 arguments (1 given)' diff --git a/tests/snippets/exceptions.py b/tests/snippets/exceptions.py index 60ba5b23a..8ade9c8de 100644 --- a/tests/snippets/exceptions.py +++ b/tests/snippets/exceptions.py @@ -15,6 +15,8 @@ exc = KeyError('message') assert str(exc) == "'message'" assert round_trip_repr(exc) +assert LookupError.__str__(exc) == "message" + exc = KeyError('message', 'another message') assert str(exc) == "('message', 'another message')" assert round_trip_repr(exc) diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index f2f95fb22..ed616e9c1 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -131,7 +131,7 @@ impl PyBaseException { #[pymethod(name = "__str__")] fn str(&self, vm: &VirtualMachine) -> PyStringRef { - let str_args = exception_args_as_string(vm, self.args(), false); + let str_args = exception_args_as_string(vm, self.args(), true); match str_args.into_iter().exactly_one() { Err(i) if i.len() == 0 => PyString::from("").into_ref(vm), Ok(s) => s, @@ -616,6 +616,18 @@ fn make_arg_getter(idx: usize) -> impl Fn(PyBaseExceptionRef, &VirtualMachine) - } } +fn key_error_str(exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyStringRef { + let args = exc.args(); + if args.as_slice().len() == 1 { + exception_args_as_string(vm, args, false) + .into_iter() + .exactly_one() + .unwrap() + } else { + exc.str(vm) + } +} + pub fn init(ctx: &PyContext) { let excs = &ctx.exceptions; @@ -638,6 +650,10 @@ pub fn init(ctx: &PyContext) { "value" => ctx.new_property(make_arg_getter(0)), }); + extend_class!(ctx, &excs.key_error, { + "__str__" => ctx.new_method(key_error_str), + }); + extend_class!(ctx, &excs.unicode_decode_error, { "encoding" => ctx.new_property(make_arg_getter(0)), "object" => ctx.new_property(make_arg_getter(1)), diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index d8ddb06f7..1e61c6e4d 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -254,13 +254,20 @@ impl PyClassRef { fn tp_new(metatype: PyClassRef, args: PyFuncArgs, vm: &VirtualMachine) -> PyResult { vm_trace!("type.__new__ {:?}", args); - if metatype.is(&vm.ctx.types.type_type) { - if args.args.len() == 1 && args.kwargs.is_empty() { - return Ok(args.args[0].class().into_object()); - } - if args.args.len() != 3 { - return Err(vm.new_type_error("type() takes 1 or 3 arguments".to_string())); - } + let is_type_type = metatype.is(&vm.ctx.types.type_type); + if is_type_type && args.args.len() == 1 && args.kwargs.is_empty() { + return Ok(args.args[0].class().into_object()); + } + + if args.args.len() != 3 { + return Err(vm.new_type_error(if is_type_type { + "type() takes 1 or 3 arguments".to_string() + } else { + format!( + "type.__new__() takes exactly 3 arguments ({} given)", + args.args.len() + ) + })); } let (name, bases, dict): (PyStringRef, PyIterable, PyDictRef) =