diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index c14599e84d..2fccad238c 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -1,19 +1,20 @@ use super::{ PyBytes, PyBytesRef, PyInt, PyListRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, }; -use crate::common::{ - borrow::{BorrowedValue, BorrowedValueMut}, - hash::PyHash, - lock::OnceCell, -}; use crate::{ bytesinner::bytes_to_hex, + common::{ + borrow::{BorrowedValue, BorrowedValueMut}, + hash::PyHash, + lock::OnceCell, + static_cell, + }, function::{FuncArgs, IntoPyObject, OptionalArg}, protocol::{BufferDescriptor, BufferMethods, PyBuffer, PyMappingMethods, VecBuffer}, sequence::SequenceOp, sliceable::wrap_index, stdlib::pystruct::FormatSpec, - types::{AsBuffer, AsMapping, Comparable, Constructor, Hashable, PyComparisonOp}, + types::{AsBuffer, AsMapping, AsSequence, Comparable, Constructor, Hashable, PyComparisonOp}, utils::Either, IdProtocol, PyClassImpl, PyComparisonValue, PyContext, PyObject, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, PyValue, TryFromBorrowedObject, TryFromObject, TypeProtocol, @@ -21,7 +22,7 @@ use crate::{ }; use crossbeam_utils::atomic::AtomicCell; use itertools::Itertools; -use std::{cmp::Ordering, fmt::Debug, mem::ManuallyDrop, ops::Range}; +use std::{borrow::Cow, cmp::Ordering, fmt::Debug, mem::ManuallyDrop, ops::Range}; #[derive(FromArgs)] pub struct PyMemoryViewNewArgs { @@ -977,6 +978,22 @@ impl AsMapping for PyMemoryView { } } +impl AsSequence for PyMemoryView { + fn as_sequence(_zelf: &PyObjectView, _vm: &VirtualMachine) -> Cow<'static, PySequenceMethods> { + static_cell! { + static METHODS: PySequenceMethods; + } + Cow::Borrowed(METHODS.get_or_init(|| PySequenceMethods { + length: Some(|zelf, vm| zelf.payload::().unwrap().len(vm)), + item: Some(|zelf, i, vm| { + let zelf = zelf.clone().downcast::().unwrap(); + zelf.getitem_by_idx(i, vm) + }), + ..Default::default() + })) + } +} + impl Comparable for PyMemoryView { fn cmp( zelf: &crate::PyObjectView, diff --git a/vm/src/protocol/mod.rs b/vm/src/protocol/mod.rs index 24c88c31da..9415bb455e 100644 --- a/vm/src/protocol/mod.rs +++ b/vm/src/protocol/mod.rs @@ -2,7 +2,9 @@ mod buffer; mod iter; mod mapping; mod object; +mod sequence; pub use buffer::{BufferDescriptor, BufferMethods, BufferResizeGuard, PyBuffer, VecBuffer}; pub use iter::{PyIter, PyIterIter, PyIterReturn}; pub use mapping::{PyMapping, PyMappingMethods}; +pub use sequence::{PySequence, PySequenceMethods}; diff --git a/vm/src/protocol/sequence.rs b/vm/src/protocol/sequence.rs new file mode 100644 index 0000000000..5ad174f8e0 --- /dev/null +++ b/vm/src/protocol/sequence.rs @@ -0,0 +1,55 @@ +use std::borrow::{Borrow, Cow}; + +use crate::{IdProtocol, PyObjectRef, PyResult, TypeProtocol, VirtualMachine}; + +// Sequence Protocol +// https://docs.python.org/3/c-api/sequence.html + +#[allow(clippy::type_complexity)] +#[derive(Default, Clone)] +pub struct PySequenceMethods { + pub length: Option PyResult>, + pub concat: Option PyResult>, + pub repeat: Option PyResult>, + pub inplace_concat: + Option PyResult>, + pub inplace_repeat: Option PyResult>, + pub item: Option PyResult>, + pub ass_item: + Option, &VirtualMachine) -> PyResult<()>>, + pub contains: Option PyResult>, +} + +pub struct PySequence(PyObjectRef, Cow<'static, PySequenceMethods>); + +impl PySequence { + pub fn check(obj: &PyObjectRef, vm: &VirtualMachine) -> bool { + let cls = obj.class(); + if cls.is(&vm.ctx.types.dict_type) { + return false; + } + if let Some(f) = cls.mro_find_map(|x| x.slots.as_sequence.load()) { + return f(obj, vm).item.is_some(); + } + false + } + + pub fn from_object(vm: &VirtualMachine, obj: PyObjectRef) -> Option { + let cls = obj.class(); + if cls.is(&vm.ctx.types.dict_type) { + return None; + } + let f = cls.mro_find_map(|x| x.slots.as_sequence.load())?; + drop(cls); + let methods = f(&obj, vm); + if methods.item.is_some() { + Some(Self(obj, methods)) + } else { + None + } + } + + pub fn methods(&self) -> &PySequenceMethods { + self.1.borrow() + } +} diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 9a50b8ddfc..b514a15e3f 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,4 +1,5 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; +use crate::PyArithmeticValue; use crate::{ builtins::{PyInt, PyStrRef, PyType, PyTypeRef}, function::{FromArgs, FuncArgs, IntoPyResult, OptionalArg}, @@ -9,6 +10,7 @@ use crate::{ }; use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; +use std::borrow::Cow; use std::cmp::Ordering; // The corresponding field in CPython is `tp_` prefixed. @@ -23,7 +25,7 @@ pub struct PyTypeSlots { // Method suites for standard classes // tp_as_number - // tp_as_sequence + pub as_sequence: AtomicCell>, pub as_mapping: AtomicCell>, // More standard operations (here for binary compatibility) @@ -151,19 +153,22 @@ pub(crate) type DescrSetFunc = fn(PyObjectRef, PyObjectRef, Option, &VirtualMachine) -> PyResult<()>; pub(crate) type NewFunc = fn(PyTypeRef, FuncArgs, &VirtualMachine) -> PyResult; pub(crate) type DelFunc = fn(&PyObject, &VirtualMachine) -> PyResult<()>; +pub(crate) type AsSequenceFunc = + fn(&PyObject, &VirtualMachine) -> Cow<'static, PySequenceMethods>; + +macro_rules! then_some_closure { + ($cond:expr, $closure:expr) => { + if $cond { + Some($closure) + } else { + None + } + }; +} pub use crate::builtins::object::{generic_getattr, generic_setattr}; fn as_mapping_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> PyMappingMethods { - macro_rules! then_some_closure { - ($cond:expr, $closure:expr) => { - if $cond { - Some($closure) - } else { - None - } - }; - } PyMappingMethods { length: then_some_closure!(zelf.has_class_attr("__len__"), |mapping, vm| { vm.call_special_method(mapping.obj.to_owned(), "__len__", ()) @@ -202,6 +207,90 @@ fn as_mapping_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> PyMappingMethods } } +fn as_sequence_wrapper( + zelf: &PyObject, + _vm: &VirtualMachine, +) -> Cow<'static, PySequenceMethods> { + Cow::Owned(PySequenceMethods { + length: then_some_closure!(zelf.has_class_attr("__len__"), |zelf, vm| { + vm.obj_len_opt(zelf).unwrap() + }), + concat: then_some_closure!(zelf.has_class_attr("__add__"), |zelf, other, vm| { + if PySequence::check(zelf, vm) && PySequence::check(other, vm) { + let ret = vm.call_special_method(zelf.clone(), "__add__", (other.clone(),))?; + if let PyArithmeticValue::Implemented(obj) = PyArithmeticValue::from_object(vm, ret) + { + return Ok(obj); + } + } + Err(vm.new_type_error(format!("'{}' object can't be concatenated", zelf))) + }), + repeat: then_some_closure!(zelf.has_class_attr("__mul__"), |zelf, n, vm| { + if PySequence::check(zelf, vm) { + let ret = + vm.call_special_method(zelf.clone(), "__mul__", (n.into_pyobject(vm),))?; + if let PyArithmeticValue::Implemented(obj) = PyArithmeticValue::from_object(vm, ret) + { + return Ok(obj); + } + } + Err(vm.new_type_error(format!("'{}' object can't be repeated", zelf))) + }), + inplace_concat: then_some_closure!( + zelf.has_class_attr("__iadd__") || zelf.has_class_attr("__add__"), + |zelf, other, vm| { + if PySequence::check(&zelf, vm) && PySequence::check(other, vm) { + if let Ok(f) = vm.get_special_method(zelf.clone(), "__iadd__")? { + let ret = f.invoke((other.clone(),), vm)?; + if let PyArithmeticValue::Implemented(obj) = + PyArithmeticValue::from_object(vm, ret) + { + return Ok(obj); + } + } + if let Ok(f) = vm.get_special_method(zelf.clone(), "__add__")? { + let ret = f.invoke((other.clone(),), vm)?; + if let PyArithmeticValue::Implemented(obj) = + PyArithmeticValue::from_object(vm, ret) + { + return Ok(obj); + } + } + } + Err(vm.new_type_error(format!("'{}' object can't be concatenated", zelf))) + } + ), + inplace_repeat: then_some_closure!( + zelf.has_class_attr("__imul__") || zelf.has_class_attr("__mul__"), + |zelf, n, vm| { + if PySequence::check(&zelf, vm) { + if let Ok(f) = vm.get_special_method(zelf.clone(), "__imul__")? { + let ret = f.invoke((n.into_pyobject(vm),), vm)?; + if let PyArithmeticValue::Implemented(obj) = + PyArithmeticValue::from_object(vm, ret) + { + return Ok(obj); + } + } + if let Ok(f) = vm.get_special_method(zelf.clone(), "__mul__")? { + let ret = f.invoke((n.into_pyobject(vm),), vm)?; + if let PyArithmeticValue::Implemented(obj) = + PyArithmeticValue::from_object(vm, ret) + { + return Ok(obj); + } + } + } + Err(vm.new_type_error(format!("'{}' object can't be repeated", zelf))) + } + ), + item: None, + ass_item: None, + // TODO: IterSearch + contains: None, + }) +} + fn hash_wrapper(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { let hash_obj = vm.call_special_method(zelf.to_owned(), "__hash__", ())?; match hash_obj.payload_if_subclass::(vm) { @@ -302,7 +391,10 @@ impl PyType { match name { "__len__" | "__getitem__" | "__setitem__" | "__delitem__" => { update_slot!(as_mapping, as_mapping_wrapper); - // TODO: need to update sequence protocol too + update_slot!(as_sequence, as_sequence_wrapper); + } + "__add__" | "__iadd__" | "__mul__" | "__imul__" => { + update_slot!(as_sequence, as_sequence_wrapper); } "__hash__" => { update_slot!(hash, hash_wrapper); @@ -791,6 +883,20 @@ pub trait AsMapping: PyValue { } } +#[pyimpl] +pub trait AsSequence: PyValue { + #[inline] + #[pyslot] + fn slot_as_sequence( + zelf: &PyObject, + vm: &VirtualMachine, + ) -> Cow<'static, PySequenceMethods> { + let zelf = unsafe { zelf.downcast_unchecked_ref::() }; + Self::as_sequence(zelf, vm) + } + fn as_sequence(zelf: &PyObjectView, vm: &VirtualMachine) -> Cow<'static, PySequenceMethods>; +} + #[pyimpl] pub trait Iterable: PyValue { #[pyslot]