diff --git a/vm/src/builtins/mod.rs b/vm/src/builtins/mod.rs index b199b89f8..d7478b180 100644 --- a/vm/src/builtins/mod.rs +++ b/vm/src/builtins/mod.rs @@ -70,7 +70,7 @@ pub(crate) mod staticmethod; pub use staticmethod::PyStaticMethod; pub(crate) mod traceback; pub use traceback::PyTraceback; -pub(crate) mod tuple; +pub mod tuple; pub use tuple::PyTuple; pub(crate) mod weakproxy; pub use weakproxy::PyWeakProxy; diff --git a/vm/src/builtins/tuple.rs b/vm/src/builtins/tuple.rs index 91ce4050d..a47d8c69a 100644 --- a/vm/src/builtins/tuple.rs +++ b/vm/src/builtins/tuple.rs @@ -1,12 +1,14 @@ use crossbeam_utils::atomic::AtomicCell; use std::fmt; +use std::marker::PhantomData; use super::pytype::PyTypeRef; use crate::common::hash::PyHash; use crate::function::OptionalArg; use crate::pyobject::{ self, BorrowValue, Either, IdProtocol, IntoPyObject, PyArithmaticValue, PyClassImpl, - PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TypeProtocol, + PyComparisonValue, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TransmuteFromObject, + TryFromObject, TypeProtocol, }; use crate::sequence::{self, SimpleSeq}; use crate::sliceable::PySliceableSequence; @@ -262,7 +264,7 @@ impl Iterable for PyTuple { #[pyclass(module = false, name = "tuple_iterator")] #[derive(Debug)] -pub struct PyTupleIterator { +pub(crate) struct PyTupleIterator { position: AtomicCell, tuple: PyTupleRef, } @@ -287,9 +289,35 @@ impl PyIter for PyTupleIterator { } } -pub fn init(context: &PyContext) { - let tuple_type = &context.types.tuple_type; - PyTuple::extend_class(context, tuple_type); - +pub(crate) fn init(context: &PyContext) { + PyTuple::extend_class(context, &context.types.tuple_type); PyTupleIterator::extend_class(context, &context.types.tuple_iterator_type); } + +pub struct PyTupleTyped { + // SAFETY INVARIANT: T must be repr(transparent) over PyObjectRef, and the + // elements must be logically valid when transmuted to T + tuple: PyTupleRef, + _marker: PhantomData>, +} + +impl TryFromObject for PyTupleTyped { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + let tuple = PyTupleRef::try_from_object(vm, obj)?; + for elem in tuple.borrow_value() { + T::check(vm, elem)? + } + // SAFETY: the contract of TransmuteFromObject upholds the variant on `tuple` + Ok(Self { + tuple, + _marker: PhantomData, + }) + } +} + +impl<'a, T: 'a> BorrowValue<'a> for PyTupleTyped { + type Borrowed = &'a [T]; + fn borrow_value(&'a self) -> Self::Borrowed { + unsafe { &*(self.tuple.borrow_value() as *const [PyObjectRef] as *const [T]) } + } +} diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 083b47972..03dc97d71 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -785,6 +785,40 @@ pub trait TryFromObject: Sized { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult; } +/// Marks a type that has the exact same layout as PyObjectRef, e.g. a type that is +/// `repr(transparent)` over PyObjectRef. +/// +/// # Safety +/// Can only be implemented for types that are `repr(transparent)` over a PyObjectRef `obj`, +/// and logically valid so long as `check(vm, obj)` returns `Ok(())` +pub unsafe trait TransmuteFromObject: Sized { + fn check(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult<()>; +} + +unsafe impl TransmuteFromObject for PyRef { + fn check(vm: &VirtualMachine, obj: &PyObjectRef) -> PyResult<()> { + let class = T::class(vm); + if obj.isinstance(class) { + if obj.payload_is::() { + Ok(()) + } else { + Err(vm.new_runtime_error(format!( + "Unexpected payload '{}' for type '{}'", + class.name, + obj.class().name, + ))) + } + } else { + let expected_type = &class.name; + let actual_type = &obj.class().name; + Err(vm.new_type_error(format!( + "Expected type '{}', not '{}'", + expected_type, actual_type, + ))) + } + } +} + pub trait IntoPyRef { fn into_pyref(self, vm: &VirtualMachine) -> PyRef; }