diff --git a/vm/src/builtins/memory.rs b/vm/src/builtins/memory.rs index 2fccad238c..f54e3ec8cf 100644 --- a/vm/src/builtins/memory.rs +++ b/vm/src/builtins/memory.rs @@ -979,7 +979,10 @@ impl AsMapping for PyMemoryView { } impl AsSequence for PyMemoryView { - fn as_sequence(_zelf: &PyObjectView, _vm: &VirtualMachine) -> Cow<'static, PySequenceMethods> { + fn as_sequence( + _zelf: &PyObjectView, + _vm: &VirtualMachine, + ) -> Cow<'static, PySequenceMethods> { static_cell! { static METHODS: PySequenceMethods; } diff --git a/vm/src/protocol/mod.rs b/vm/src/protocol/mod.rs index 9415bb455e..ea2c15f870 100644 --- a/vm/src/protocol/mod.rs +++ b/vm/src/protocol/mod.rs @@ -2,7 +2,7 @@ mod buffer; mod iter; mod mapping; mod object; -mod sequence; +pub(crate) mod sequence; pub use buffer::{BufferDescriptor, BufferMethods, BufferResizeGuard, PyBuffer, VecBuffer}; pub use iter::{PyIter, PyIterIter, PyIterReturn}; diff --git a/vm/src/protocol/sequence.rs b/vm/src/protocol/sequence.rs index 5ad174f8e0..a6ed98f983 100644 --- a/vm/src/protocol/sequence.rs +++ b/vm/src/protocol/sequence.rs @@ -1,6 +1,10 @@ use std::borrow::{Borrow, Cow}; -use crate::{IdProtocol, PyObjectRef, PyResult, TypeProtocol, VirtualMachine}; +use crate::builtins::PySlice; +use crate::function::IntoPyObject; +use crate::{ + IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, PyValue, TypeProtocol, VirtualMachine, +}; // Sequence Protocol // https://docs.python.org/3/c-api/sequence.html @@ -9,18 +13,20 @@ use crate::{IdProtocol, PyObjectRef, PyResult, TypeProtocol, VirtualMachine}; #[derive(Default, Clone)] pub struct PySequenceMethods { pub length: Option PyResult>, - pub concat: Option PyResult>, - pub repeat: Option PyResult>, - pub inplace_concat: - Option PyResult>, - pub inplace_repeat: Option PyResult>, - pub item: Option PyResult>, + pub concat: Option PyResult>, + pub repeat: Option PyResult>, + pub inplace_concat: Option PyResult>, + pub inplace_repeat: Option PyResult>, + pub item: Option PyResult>, pub ass_item: - Option, &VirtualMachine) -> PyResult<()>>, + Option, &VirtualMachine) -> PyResult<()>>, pub contains: Option PyResult>, } -pub struct PySequence(PyObjectRef, Cow<'static, PySequenceMethods>); +pub struct PySequence { + obj: PyObjectRef, + methods: Cow<'static, PySequenceMethods>, +} impl PySequence { pub fn check(obj: &PyObjectRef, vm: &VirtualMachine) -> bool { @@ -28,10 +34,9 @@ impl PySequence { if cls.is(&vm.ctx.types.dict_type) { return false; } - if let Some(f) = cls.mro_find_map(|x| x.slots.as_sequence.load()) { - return f(obj, vm).item.is_some(); - } - false + cls.mro_find_map(|x| x.slots.as_sequence.load()) + .map(|f| f(obj, vm).item.is_some()) + .unwrap_or(false) } pub fn from_object(vm: &VirtualMachine, obj: PyObjectRef) -> Option { @@ -43,13 +48,198 @@ impl PySequence { drop(cls); let methods = f(&obj, vm); if methods.item.is_some() { - Some(Self(obj, methods)) + Some(Self { obj, methods }) } else { None } } pub fn methods(&self) -> &PySequenceMethods { - self.1.borrow() + self.methods.borrow() + } + + pub fn length(&self, vm: &VirtualMachine) -> PyResult { + if let Some(f) = self.methods().length { + f(&self.obj, vm) + } else { + Err(vm.new_type_error(format!("'{}' is not a sequence or has no len()", &self.obj))) + } + } + + pub fn concat(&self, other: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(f) = self.methods().concat { + return f(&self.obj, other, vm); + } + try_add_for_concat(&self.obj, other, vm) + } + + pub fn repeat(&self, n: usize, vm: &VirtualMachine) -> PyResult { + if let Some(f) = self.methods().repeat { + return f(&self.obj, n, vm); + } + try_mul_for_repeat(&self.obj, n, vm) + } + + pub fn inplace_concat(&self, other: &PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Some(f) = self.methods().inplace_concat { + return f(&self.obj, other, vm); + } + if let Some(f) = self.methods().concat { + return f(&self.obj, other, vm); + } + try_iadd_for_inplace_concat(&self.obj, other, vm) + } + + pub fn inplace_repeat(&self, n: usize, vm: &VirtualMachine) -> PyResult { + if let Some(f) = self.methods().inplace_repeat { + return f(&self.obj, n, vm); + } + if let Some(f) = self.methods().repeat { + return f(&self.obj, n, vm); + } + try_imul_for_inplace_repeat(&self.obj, n, vm) + } + + pub fn get_item(&self, i: isize, vm: &VirtualMachine) -> PyResult { + if let Some(f) = self.methods().item { + return f(&self.obj, i, vm); + } + Err(vm.new_type_error(format!( + "'{}' is not a sequence or does not support indexing", + &self.obj + ))) + } + + fn _ass_item(&self, i: isize, value: Option, vm: &VirtualMachine) -> PyResult<()> { + if let Some(f) = self.methods().ass_item { + return f(&self.obj, i, value, vm); + } + Err(vm.new_type_error(format!( + "'{}' is not a sequence or doesn't support item {}", + &self.obj, + if value.is_some() { + "assignment" + } else { + "deletion" + } + ))) + } + + pub fn set_item(&self, i: isize, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + self._ass_item(i, Some(value), vm) + } + + pub fn del_item(&self, i: isize, vm: &VirtualMachine) -> PyResult<()> { + self._ass_item(i, None, vm) + } + + pub fn get_slice(&self, start: isize, stop: isize, vm: &VirtualMachine) -> PyResult { + if let Some(f) = self.obj.class().mro_find_map(|x| x.slots.as_mapping.load()) { + let mp = f(&self.obj, vm); + if let Some(subscript) = mp.subscript { + let slice = PySlice { + start: Some(start.into_pyobject(vm)), + stop: stop.into_pyobject(vm), + step: None, + }; + + return subscript(self.obj.clone(), slice.into_object(vm), vm); + } + } + Err(vm.new_type_error(format!("'{}' object is unsliceable", &self.obj))) + } + + fn _ass_slice( + &self, + start: isize, + stop: isize, + value: Option, + vm: &VirtualMachine, + ) -> PyResult<()> { + let cls = self.obj.class(); + if let Some(f) = cls.mro_find_map(|x| x.slots.as_mapping.load()) { + drop(cls); + let mp = f(&self.obj, vm); + if let Some(ass_subscript) = mp.ass_subscript { + let slice = PySlice { + start: Some(start.into_pyobject(vm)), + stop: stop.into_pyobject(vm), + step: None, + }; + + return ass_subscript(self.obj.clone(), slice.into_object(vm), value, vm); + } + } + Err(vm.new_type_error(format!( + "'{}' object doesn't support slice {}", + &self.obj, + if value.is_some() { + "assignment" + } else { + "deletion" + } + ))) + } + + pub fn set_slice( + &self, + start: isize, + stop: isize, + value: PyObjectRef, + vm: &VirtualMachine, + ) -> PyResult<()> { + self._ass_slice(start, stop, Some(value), vm) + } + + pub fn del_slice(&self, start: isize, stop: isize, vm: &VirtualMachine) -> PyResult<()> { + self._ass_slice(start, stop, None, vm) } } + +pub(crate) fn try_add_for_concat( + a: &PyObjectRef, + b: &PyObjectRef, + vm: &VirtualMachine, +) -> PyResult { + if PySequence::check(b, vm) { + let ret = vm._add(a, b)?; + if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) { + return Ok(ret); + } + } + Err(vm.new_type_error(format!("'{}' object can't be concatenated", a))) +} + +pub(crate) fn try_mul_for_repeat(a: &PyObjectRef, n: usize, vm: &VirtualMachine) -> PyResult { + let ret = vm._mul(a, &n.into_pyobject(vm))?; + if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) { + return Ok(ret); + } + Err(vm.new_type_error(format!("'{}' object can't be repeated", a))) +} + +pub(crate) fn try_iadd_for_inplace_concat( + a: &PyObjectRef, + b: &PyObjectRef, + vm: &VirtualMachine, +) -> PyResult { + if PySequence::check(b, vm) { + let ret = vm._iadd(a, b)?; + if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) { + return Ok(ret); + } + } + Err(vm.new_type_error(format!("'{}' object can't be concatenated", a))) +} + +pub(crate) fn try_imul_for_inplace_repeat( + a: &PyObjectRef, + n: usize, + vm: &VirtualMachine, +) -> PyResult { + let ret = vm._imul(a, &n.into_pyobject(vm))?; + if let PyArithmeticValue::Implemented(ret) = PyArithmeticValue::from_object(vm, ret) { + return Ok(ret); + } + Err(vm.new_type_error(format!("'{}' object can't be repeated", a))) +} diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index b514a15e3f..40840b977c 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -1,5 +1,4 @@ use crate::common::{hash::PyHash, lock::PyRwLock}; -use crate::PyArithmeticValue; use crate::{ builtins::{PyInt, PyStrRef, PyType, PyTypeRef}, function::{FromArgs, FuncArgs, IntoPyResult, OptionalArg}, @@ -10,6 +9,7 @@ use crate::{ }; use crossbeam_utils::atomic::AtomicCell; use num_traits::ToPrimitive; +use rustpython_common::static_cell; use std::borrow::Cow; use std::cmp::Ordering; @@ -153,8 +153,7 @@ pub(crate) type DescrSetFunc = fn(PyObjectRef, PyObjectRef, Option, &VirtualMachine) -> PyResult<()>; pub(crate) type NewFunc = fn(PyTypeRef, FuncArgs, &VirtualMachine) -> PyResult; pub(crate) type DelFunc = fn(&PyObject, &VirtualMachine) -> PyResult<()>; -pub(crate) type AsSequenceFunc = - fn(&PyObject, &VirtualMachine) -> Cow<'static, PySequenceMethods>; +pub(crate) type AsSequenceFunc = fn(&PyObject, &VirtualMachine) -> Cow<'static, PySequenceMethods>; macro_rules! then_some_closure { ($cond:expr, $closure:expr) => { @@ -207,85 +206,46 @@ fn as_mapping_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> PyMappingMethods } } -fn as_sequence_wrapper( - zelf: &PyObject, - _vm: &VirtualMachine, -) -> Cow<'static, PySequenceMethods> { +fn as_sequence_wrapper(zelf: &PyObject, _vm: &VirtualMachine) -> Cow<'static, PySequenceMethods> { + static_cell! { + static EMPTY: PySequenceMethods; + } + if !zelf.has_class_attr("__getitem__") { + return Cow::Borrowed(EMPTY.get_or_init(PySequenceMethods::default)); + } + Cow::Owned(PySequenceMethods { length: then_some_closure!(zelf.has_class_attr("__len__"), |zelf, vm| { vm.obj_len_opt(zelf).unwrap() }), concat: then_some_closure!(zelf.has_class_attr("__add__"), |zelf, other, vm| { - if PySequence::check(zelf, vm) && PySequence::check(other, vm) { - let ret = vm.call_special_method(zelf.clone(), "__add__", (other.clone(),))?; - if let PyArithmeticValue::Implemented(obj) = PyArithmeticValue::from_object(vm, ret) - { - return Ok(obj); - } - } - Err(vm.new_type_error(format!("'{}' object can't be concatenated", zelf))) + try_add_for_concat(zelf, other, vm) }), repeat: then_some_closure!(zelf.has_class_attr("__mul__"), |zelf, n, vm| { - if PySequence::check(zelf, vm) { - let ret = - vm.call_special_method(zelf.clone(), "__mul__", (n.into_pyobject(vm),))?; - if let PyArithmeticValue::Implemented(obj) = PyArithmeticValue::from_object(vm, ret) - { - return Ok(obj); - } - } - Err(vm.new_type_error(format!("'{}' object can't be repeated", zelf))) + try_mul_for_repeat(zelf, n, vm) }), inplace_concat: then_some_closure!( zelf.has_class_attr("__iadd__") || zelf.has_class_attr("__add__"), - |zelf, other, vm| { - if PySequence::check(&zelf, vm) && PySequence::check(other, vm) { - if let Ok(f) = vm.get_special_method(zelf.clone(), "__iadd__")? { - let ret = f.invoke((other.clone(),), vm)?; - if let PyArithmeticValue::Implemented(obj) = - PyArithmeticValue::from_object(vm, ret) - { - return Ok(obj); - } - } - if let Ok(f) = vm.get_special_method(zelf.clone(), "__add__")? { - let ret = f.invoke((other.clone(),), vm)?; - if let PyArithmeticValue::Implemented(obj) = - PyArithmeticValue::from_object(vm, ret) - { - return Ok(obj); - } - } - } - Err(vm.new_type_error(format!("'{}' object can't be concatenated", zelf))) - } + |zelf, other, vm| { try_iadd_for_inplace_concat(zelf, other, vm) } ), inplace_repeat: then_some_closure!( zelf.has_class_attr("__imul__") || zelf.has_class_attr("__mul__"), - |zelf, n, vm| { - if PySequence::check(&zelf, vm) { - if let Ok(f) = vm.get_special_method(zelf.clone(), "__imul__")? { - let ret = f.invoke((n.into_pyobject(vm),), vm)?; - if let PyArithmeticValue::Implemented(obj) = - PyArithmeticValue::from_object(vm, ret) - { - return Ok(obj); - } - } - if let Ok(f) = vm.get_special_method(zelf.clone(), "__mul__")? { - let ret = f.invoke((n.into_pyobject(vm),), vm)?; - if let PyArithmeticValue::Implemented(obj) = - PyArithmeticValue::from_object(vm, ret) - { - return Ok(obj); - } - } - } - Err(vm.new_type_error(format!("'{}' object can't be repeated", zelf))) + |zelf, n, vm| { try_imul_for_inplace_repeat(zelf, n, vm) } + ), + item: Some(|zelf, i, vm| { + vm.call_special_method(zelf.clone(), "__getitem__", (i.into_pyobject(vm),)) + }), + ass_item: then_some_closure!( + zelf.has_class_attr("__setitem__") | zelf.has_class_attr("__delitem__"), + |zelf, i, value, vm| match value { + Some(value) => vm + .call_special_method(zelf.clone(), "__setitem__", (i.into_pyobject(vm), value),) + .map(|_| Ok(()))?, + None => vm + .call_special_method(zelf.clone(), "__delitem__", (i.into_pyobject(vm),)) + .map(|_| Ok(()))?, } ), - item: None, - ass_item: None, // TODO: IterSearch contains: None, }) @@ -887,14 +847,14 @@ pub trait AsMapping: PyValue { pub trait AsSequence: PyValue { #[inline] #[pyslot] - fn slot_as_sequence( - zelf: &PyObject, - vm: &VirtualMachine, - ) -> Cow<'static, PySequenceMethods> { + fn slot_as_sequence(zelf: &PyObject, vm: &VirtualMachine) -> Cow<'static, PySequenceMethods> { let zelf = unsafe { zelf.downcast_unchecked_ref::() }; Self::as_sequence(zelf, vm) } - fn as_sequence(zelf: &PyObjectView, vm: &VirtualMachine) -> Cow<'static, PySequenceMethods>; + fn as_sequence( + zelf: &PyObjectView, + vm: &VirtualMachine, + ) -> Cow<'static, PySequenceMethods>; } #[pyimpl]