diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 2e4678d55..57bc1428b 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -15,7 +15,7 @@ use crate::{ format::{format, format_map}, function::{ArgIterable, ArgSize, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue}, intern::PyInterned, - object::{Traverse, TraverseFn}, + object::{MaybeTraverse, Traverse, TraverseFn}, protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods}, sequence::SequenceExt, sliceable::{SequenceIndex, SliceableSequenceOp}, @@ -64,6 +64,9 @@ impl<'a> TryFromBorrowedObject<'a> for &'a Wtf8 { } } +pub type PyStrRef = PyRef; +pub type PyUtf8StrRef = PyRef; + #[pyclass(module = false, name = "str")] pub struct PyStr { data: StrData, @@ -80,30 +83,6 @@ impl fmt::Debug for PyStr { } } -#[repr(transparent)] -#[derive(Debug)] -pub struct PyUtf8Str(PyStr); - -// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str -impl std::ops::Deref for PyUtf8Str { - type Target = PyStr; - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl PyUtf8Str { - /// Returns the underlying string slice. - pub fn as_str(&self) -> &str { - debug_assert!( - self.0.is_utf8(), - "PyUtf8Str invariant violated: inner string is not valid UTF-8" - ); - // Safety: This is safe because the type invariant guarantees UTF-8 validity. - unsafe { self.0.to_str().unwrap_unchecked() } - } -} - impl AsRef for PyStr { #[track_caller] // <- can remove this once it doesn't panic fn as_ref(&self) -> &str { @@ -241,8 +220,6 @@ impl Default for PyStr { } } -pub type PyStrRef = PyRef; - impl fmt::Display for PyStr { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -374,7 +351,7 @@ impl Constructor for PyStr { type Args = StrArgs; fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult { - let string: PyStrRef = match args.object { + let string: PyRef = match args.object { OptionalArg::Present(input) => { if let OptionalArg::Present(enc) = args.encoding { vm.state.codec_registry.decode_text( @@ -458,7 +435,7 @@ impl PyStr { self.data.as_str() } - pub fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { + fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> { if self.is_utf8() { Ok(()) } else { @@ -531,6 +508,22 @@ impl PyStr { .mul(vm, value) .map(|x| Self::from(unsafe { Wtf8Buf::from_bytes_unchecked(x) }).into_ref(&vm.ctx)) } + + pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a PyUtf8Str> { + // Check if the string contains surrogates + self.ensure_valid_utf8(vm)?; + // If no surrogates, we can safely cast to PyStr + Ok(unsafe { &*(self as *const _ as *const PyUtf8Str) }) + } +} + +impl Py { + pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a Py> { + // Check if the string contains surrogates + self.ensure_valid_utf8(vm)?; + // If no surrogates, we can safely cast to PyStr + Ok(unsafe { &*(self as *const _ as *const Py) }) + } } #[pyclass( @@ -980,7 +973,11 @@ impl PyStr { } #[pymethod(name = "__format__")] - fn __format__(zelf: PyRef, spec: PyStrRef, vm: &VirtualMachine) -> PyResult { + fn __format__( + zelf: PyRef, + spec: PyStrRef, + vm: &VirtualMachine, + ) -> PyResult> { let spec = spec.as_str(); if spec.is_empty() { return if zelf.class().is(vm.ctx.types.str_type) { @@ -989,7 +986,7 @@ impl PyStr { zelf.as_object().str(vm) }; } - + let zelf = zelf.try_into_utf8(vm)?; let s = FormatSpec::parse(spec) .and_then(|format_spec| { format_spec.format_string(&CharLenStr(zelf.as_str(), zelf.char_len())) @@ -1351,8 +1348,12 @@ impl PyStr { } #[pymethod] - fn expandtabs(&self, args: anystr::ExpandTabsArgs) -> String { - rustpython_common::str::expandtabs(self.as_str(), args.tabsize()) + fn expandtabs(&self, args: anystr::ExpandTabsArgs, vm: &VirtualMachine) -> PyResult { + // TODO: support WTF-8 + Ok(rustpython_common::str::expandtabs( + self.try_as_utf8(vm)?.as_str(), + args.tabsize(), + )) } #[pymethod] @@ -1480,20 +1481,6 @@ impl PyStr { } } -struct CharLenStr<'a>(&'a str, usize); -impl std::ops::Deref for CharLenStr<'_> { - type Target = str; - - fn deref(&self) -> &Self::Target { - self.0 - } -} -impl crate::common::format::CharLen for CharLenStr<'_> { - fn char_len(&self) -> usize { - self.1 - } -} - #[pyclass] impl PyRef { #[pymethod] @@ -1504,7 +1491,7 @@ impl PyRef { } } -impl PyStrRef { +impl PyRef { pub fn is_empty(&self) -> bool { (**self).is_empty() } @@ -1526,6 +1513,20 @@ impl PyStrRef { } } +struct CharLenStr<'a>(&'a str, usize); +impl std::ops::Deref for CharLenStr<'_> { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0 + } +} +impl crate::common::format::CharLen for CharLenStr<'_> { + fn char_len(&self) -> usize { + self.1 + } +} + impl Representable for PyStr { #[inline] fn repr_str(zelf: &Py, vm: &VirtualMachine) -> PyResult { @@ -1941,6 +1942,170 @@ impl AnyStrWrapper for PyStrRef { } } +#[repr(transparent)] +#[derive(Debug)] +pub struct PyUtf8Str(PyStr); + +impl fmt::Display for PyUtf8Str { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl MaybeTraverse for PyUtf8Str { + fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>) { + self.0.try_traverse(traverse_fn); + } +} + +impl PyPayload for PyUtf8Str { + #[inline] + fn class(ctx: &Context) -> &'static Py { + ctx.types.str_type + } + + fn payload_type_id() -> std::any::TypeId { + std::any::TypeId::of::() + } + + fn downcastable_from(obj: &PyObject) -> bool { + obj.typeid() == Self::payload_type_id() && { + // SAFETY: we know the object is a PyStr in this context + let wtf8 = unsafe { obj.downcast_unchecked_ref::() }; + wtf8.is_utf8() + } + } + + fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> { + let str = obj.try_downcast_ref::(vm)?; + str.ensure_valid_utf8(vm) + } +} + +impl<'a> From<&'a AsciiStr> for PyUtf8Str { + fn from(s: &'a AsciiStr) -> Self { + s.to_owned().into() + } +} + +impl From for PyUtf8Str { + fn from(s: AsciiString) -> Self { + s.into_boxed_ascii_str().into() + } +} + +impl From> for PyUtf8Str { + fn from(s: Box) -> Self { + let data = StrData::from(s); + unsafe { Self::from_str_data_unchecked(data) } + } +} + +impl From for PyUtf8Str { + fn from(ch: AsciiChar) -> Self { + AsciiString::from(ch).into() + } +} + +impl<'a> From<&'a str> for PyUtf8Str { + fn from(s: &'a str) -> Self { + s.to_owned().into() + } +} + +impl From for PyUtf8Str { + fn from(s: String) -> Self { + s.into_boxed_str().into() + } +} + +impl From for PyUtf8Str { + fn from(ch: char) -> Self { + let data = StrData::from(ch); + unsafe { Self::from_str_data_unchecked(data) } + } +} + +impl<'a> From> for PyUtf8Str { + fn from(s: std::borrow::Cow<'a, str>) -> Self { + s.into_owned().into() + } +} + +impl From> for PyUtf8Str { + #[inline] + fn from(value: Box) -> Self { + let data = StrData::from(value); + unsafe { Self::from_str_data_unchecked(data) } + } +} + +impl AsRef for PyUtf8Str { + #[inline] + fn as_ref(&self) -> &Wtf8 { + self.0.as_wtf8() + } +} + +impl AsRef for PyUtf8Str { + #[inline] + fn as_ref(&self) -> &str { + self.0.as_str() + } +} + +impl PyUtf8Str { + // Create a new `PyUtf8Str` from `StrData` without validation. + // This function must be only used in this module to create conversions. + // # Safety: must be called with a valid UTF-8 string data. + unsafe fn from_str_data_unchecked(data: StrData) -> Self { + Self(PyStr::from(data)) + } + + /// Returns the underlying string slice. + pub fn as_str(&self) -> &str { + debug_assert!( + self.0.is_utf8(), + "PyUtf8Str invariant violated: inner string is not valid UTF-8" + ); + // Safety: This is safe because the type invariant guarantees UTF-8 validity. + unsafe { self.0.to_str().unwrap_unchecked() } + } + + #[inline] + pub fn byte_len(&self) -> usize { + self.0.byte_len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + #[inline] + pub fn char_len(&self) -> usize { + self.0.char_len() + } +} + +impl Py { + /// Upcast to PyStr. + pub fn as_pystr(&self) -> &Py { + unsafe { + // Safety: PyUtf8Str is a wrapper around PyStr, so this cast is safe. + &*(self as *const Self as *const Py) + } + } +} + +impl PartialEq for PyUtf8Str { + fn eq(&self, other: &Self) -> bool { + self.as_str() == other.as_str() + } +} +impl Eq for PyUtf8Str {} + impl AnyStrContainer for String { fn new() -> Self { Self::new() @@ -2302,7 +2467,8 @@ impl std::fmt::Display for PyStrInterned { impl AsRef for PyStrInterned { #[inline(always)] fn as_ref(&self) -> &str { - self.as_str() + self.to_str() + .expect("Interned PyStr should always be valid UTF-8") } } diff --git a/vm/src/convert/try_from.rs b/vm/src/convert/try_from.rs index 3fda682d4..4f921e9c5 100644 --- a/vm/src/convert/try_from.rs +++ b/vm/src/convert/try_from.rs @@ -78,12 +78,12 @@ where #[inline] fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let class = T::class(&vm.ctx); - let result = if obj.fast_isinstance(class) { - obj.downcast() + if obj.fast_isinstance(class) { + T::try_downcast_from(&obj, vm)?; + Ok(unsafe { obj.downcast_unchecked() }) } else { - Err(obj) - }; - result.map_err(|obj| vm.new_downcast_type_error(class, &obj)) + Err(vm.new_downcast_type_error(class, &obj)) + } } } diff --git a/vm/src/object/core.rs b/vm/src/object/core.rs index 54a6657a9..57576ce70 100644 --- a/vm/src/object/core.rs +++ b/vm/src/object/core.rs @@ -448,7 +448,7 @@ impl PyInner { let member_count = typ.slots.member_count; Box::new(Self { ref_count: RefCount::new(), - typeid: TypeId::of::(), + typeid: T::payload_type_id(), vtable: PyObjVTable::of::(), typ: PyAtomicRef::from(typ), dict: dict.map(InstanceDict::new), @@ -541,6 +541,11 @@ impl PyObjectRef { } } + pub fn try_downcast(self, vm: &VirtualMachine) -> PyResult> { + T::try_downcast_from(&self, vm)?; + Ok(unsafe { self.downcast_unchecked() }) + } + /// Force to downcast this reference to a subclass. /// /// # Safety @@ -720,10 +725,24 @@ impl PyObject { } } + #[inline] + pub(crate) fn typeid(&self) -> TypeId { + self.0.typeid + } + /// Check if this object can be downcast to T. #[inline(always)] pub fn downcastable(&self) -> bool { - self.0.typeid == T::payload_type_id() + T::downcastable_from(self) + } + + /// Attempt to downcast this reference to a subclass. + pub fn try_downcast_ref<'a, T: PyObjectPayload>( + &'a self, + vm: &VirtualMachine, + ) -> PyResult<&'a Py> { + T::try_downcast_from(self, vm)?; + Ok(unsafe { self.downcast_unchecked_ref::() }) } /// Attempt to downcast this reference to a subclass. diff --git a/vm/src/object/payload.rs b/vm/src/object/payload.rs index f223af6e9..0b7bfe0dc 100644 --- a/vm/src/object/payload.rs +++ b/vm/src/object/payload.rs @@ -1,6 +1,6 @@ use crate::object::{MaybeTraverse, Py, PyObjectRef, PyRef, PyResult}; use crate::{ - PyRefExact, + PyObject, PyRefExact, builtins::{PyBaseExceptionRef, PyType, PyTypeRef}, types::PyTypeFlags, vm::{Context, VirtualMachine}, @@ -23,6 +23,31 @@ pub trait PyPayload: fn payload_type_id() -> std::any::TypeId { std::any::TypeId::of::() } + + /// # Safety: this function should only be called if `payload_type_id` matches the type of `obj`. + #[inline] + fn downcastable_from(obj: &PyObject) -> bool { + obj.typeid() == Self::payload_type_id() + } + + fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> { + if Self::downcastable_from(obj) { + return Ok(()); + } + + #[cold] + fn raise_downcast_type_error( + vm: &VirtualMachine, + class: &Py, + obj: &PyObject, + ) -> PyBaseExceptionRef { + vm.new_downcast_type_error(class, obj) + } + + let class = Self::class(&vm.ctx); + Err(raise_downcast_type_error(vm, class, obj)) + } + fn class(ctx: &Context) -> &'static Py; #[inline]