diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 7e9b28236..4200cb8ee 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -838,8 +838,6 @@ class UnionTests(unittest.TestCase): self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T]) self.assertEqual((list[T] | list[S])[int, int], list[int]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_union_parameter_substitution(self): def eq(actual, expected, typed=True): self.assertEqual(actual, expected) diff --git a/crates/vm/src/builtins/type.rs b/crates/vm/src/builtins/type.rs index 4f150d925..bc567cd09 100644 --- a/crates/vm/src/builtins/type.rs +++ b/crates/vm/src/builtins/type.rs @@ -1963,7 +1963,7 @@ pub(crate) fn call_slot_new( slot_new(subtype, args, vm) } -pub(super) fn or_(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { +pub(crate) fn or_(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { if !union_::is_unionable(zelf.clone(), vm) || !union_::is_unionable(other.clone(), vm) { return vm.ctx.not_implemented(); } diff --git a/crates/vm/src/builtins/union.rs b/crates/vm/src/builtins/union.rs index 8310201a1..83e1316d0 100644 --- a/crates/vm/src/builtins/union.rs +++ b/crates/vm/src/builtins/union.rs @@ -8,6 +8,7 @@ use crate::{ convert::{ToPyObject, ToPyResult}, function::PyComparisonValue, protocol::{PyMappingMethods, PyNumberMethods}, + stdlib::typing::TypeAliasType, types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable}, }; use std::fmt; @@ -152,10 +153,12 @@ impl PyUnion { } pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool { - obj.class().is(vm.ctx.types.none_type) + let cls = obj.class(); + cls.is(vm.ctx.types.none_type) || obj.downcastable::() - || obj.class().is(vm.ctx.types.generic_alias_type) - || obj.class().is(vm.ctx.types.union_type) + || cls.fast_issubclass(vm.ctx.types.generic_alias_type) + || cls.is(vm.ctx.types.union_type) + || obj.downcast_ref::().is_some() } fn make_parameters(args: &Py, vm: &VirtualMachine) -> PyTupleRef { @@ -195,7 +198,7 @@ fn dedup_and_flatten_args(args: &Py, vm: &VirtualMachine) -> PyTupleRef if !new_args.iter().any(|param| { param .rich_compare_bool(arg, PyComparisonOp::Eq, vm) - .expect("types are always comparable") + .unwrap_or_default() }) { new_args.push(arg.clone()); } diff --git a/crates/vm/src/stdlib/typing.rs b/crates/vm/src/stdlib/typing.rs index 469a75b01..f11acce34 100644 --- a/crates/vm/src/stdlib/typing.rs +++ b/crates/vm/src/stdlib/typing.rs @@ -36,9 +36,11 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef { pub(crate) mod decl { use crate::{ Py, PyObjectRef, PyPayload, PyResult, VirtualMachine, - builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr}, + builtins::{PyStrRef, PyTupleRef, PyType, PyTypeRef, pystr::AsPyStr, type_}, + convert::ToPyResult, function::{FuncArgs, IntoFuncArgs}, - types::{Constructor, Representable}, + protocol::PyNumberMethods, + types::{AsNumber, Constructor, Representable}, }; pub(crate) fn _call_typing_func_object<'a>( @@ -109,7 +111,7 @@ pub(crate) mod decl { // compute_value: PyObjectRef, // module: PyObjectRef, } - #[pyclass(with(Constructor, Representable), flags(BASETYPE))] + #[pyclass(with(Constructor, Representable, AsNumber), flags(BASETYPE))] impl TypeAliasType { pub const fn new(name: PyStrRef, type_params: PyTupleRef, value: PyObjectRef) -> Self { Self { @@ -183,6 +185,16 @@ pub(crate) mod decl { } } + impl AsNumber for TypeAliasType { + fn as_number() -> &'static PyNumberMethods { + static AS_NUMBER: PyNumberMethods = PyNumberMethods { + or: Some(|a, b, vm| type_::or_(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)), + ..PyNumberMethods::NOT_IMPLEMENTED + }; + &AS_NUMBER + } + } + // impl AsMapping for Generic { // fn as_mapping() -> &'static PyMappingMethods { // static AS_MAPPING: Lazy = Lazy::new(|| PyMappingMethods { diff --git a/extra_tests/snippets/stdlib_typing.py b/extra_tests/snippets/stdlib_typing.py index ddc30b684..681790abd 100644 --- a/extra_tests/snippets/stdlib_typing.py +++ b/extra_tests/snippets/stdlib_typing.py @@ -8,3 +8,7 @@ def abort_signal_handler( fn: Callable[[], Awaitable[T]], on_abort: Callable[[], None] | None = None ) -> T: pass + + +# Ensure PEP 604 unions work with typing.Callable aliases. +TracebackFilter = bool | Callable[[int], int]