downcastable_from

This commit is contained in:
Jeong YunWon
2025-07-30 10:28:04 +09:00
parent f402deef6d
commit 053cfeecce
4 changed files with 267 additions and 57 deletions

View File

@@ -15,7 +15,7 @@ use crate::{
format::{format, format_map},
function::{ArgIterable, ArgSize, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue},
intern::PyInterned,
object::{Traverse, TraverseFn},
object::{MaybeTraverse, Traverse, TraverseFn},
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
sequence::SequenceExt,
sliceable::{SequenceIndex, SliceableSequenceOp},
@@ -64,6 +64,9 @@ impl<'a> TryFromBorrowedObject<'a> for &'a Wtf8 {
}
}
/// Shorthand for a `PyRef` whose payload is a [`PyStr`].
pub type PyStrRef = PyRef<PyStr>;
/// Shorthand for a `PyRef` whose payload is a [`PyUtf8Str`].
pub type PyUtf8StrRef = PyRef<PyUtf8Str>;
#[pyclass(module = false, name = "str")]
pub struct PyStr {
data: StrData,
@@ -80,30 +83,6 @@ impl fmt::Debug for PyStr {
}
}
/// Transparent wrapper around [`PyStr`] whose contents are expected to be valid UTF-8
/// (the invariant is asserted in debug builds by `as_str`).
#[repr(transparent)]
#[derive(Debug)]
pub struct PyUtf8Str(PyStr);
// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str
/// Exposes the wrapped [`PyStr`] so its methods can be called directly on a `PyUtf8Str`.
impl std::ops::Deref for PyUtf8Str {
type Target = PyStr;
/// Returns a reference to the wrapped [`PyStr`].
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl PyUtf8Str {
/// Returns the underlying string slice.
///
/// Debug builds assert that the wrapped string reports itself as UTF-8; the result of
/// `to_str` is then unwrapped without a runtime check.
pub fn as_str(&self) -> &str {
debug_assert!(
self.0.is_utf8(),
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
);
// SAFETY: the type invariant guarantees UTF-8 validity, so `to_str` is expected to
// yield a value here and the unchecked unwrap relies on that.
unsafe { self.0.to_str().unwrap_unchecked() }
}
}
impl AsRef<str> for PyStr {
#[track_caller] // <- can remove this once it doesn't panic
fn as_ref(&self) -> &str {
@@ -241,8 +220,6 @@ impl Default for PyStr {
}
}
/// Shorthand for a `PyRef` whose payload is a [`PyStr`].
pub type PyStrRef = PyRef<PyStr>;
impl fmt::Display for PyStr {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -374,7 +351,7 @@ impl Constructor for PyStr {
type Args = StrArgs;
fn py_new(cls: PyTypeRef, args: Self::Args, vm: &VirtualMachine) -> PyResult {
let string: PyStrRef = match args.object {
let string: PyRef<PyStr> = match args.object {
OptionalArg::Present(input) => {
if let OptionalArg::Present(enc) = args.encoding {
vm.state.codec_registry.decode_text(
@@ -458,7 +435,7 @@ impl PyStr {
self.data.as_str()
}
pub fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.is_utf8() {
Ok(())
} else {
@@ -531,6 +508,22 @@ impl PyStr {
.mul(vm, value)
.map(|x| Self::from(unsafe { Wtf8Buf::from_bytes_unchecked(x) }).into_ref(&vm.ctx))
}
/// Borrows this string as a [`PyUtf8Str`] once it has passed `ensure_valid_utf8`.
///
/// # Errors
///
/// Propagates the error from `ensure_valid_utf8` when the check fails.
pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a PyUtf8Str> {
// Reject strings that are not valid UTF-8 (e.g. ones that contain surrogates).
self.ensure_valid_utf8(vm)?;
// SAFETY: the check above succeeded, and `PyUtf8Str` is a `#[repr(transparent)]`
// wrapper around `PyStr`, so the reference can be reinterpreted as `&PyUtf8Str`.
Ok(unsafe { &*(self as *const _ as *const PyUtf8Str) })
}
}
impl Py<PyStr> {
    /// Borrows this object as a `Py<PyUtf8Str>` once its string has passed the UTF-8 check.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ensure_valid_utf8` when the string is rejected.
    pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a Py<PyUtf8Str>> {
        // Bail out early for strings that are not valid UTF-8 (e.g. containing surrogates).
        self.ensure_valid_utf8(vm)?;
        let utf8_ptr = self as *const Self as *const Py<PyUtf8Str>;
        // SAFETY: the string was validated just above, and `PyUtf8Str` is a
        // `#[repr(transparent)]` wrapper around `PyStr`. This assumes `Py<T>` is laid out
        // identically for both payloads — TODO confirm against the definition of `Py`.
        Ok(unsafe { &*utf8_ptr })
    }
}
#[pyclass(
@@ -980,7 +973,11 @@ impl PyStr {
}
#[pymethod(name = "__format__")]
fn __format__(zelf: PyRef<Self>, spec: PyStrRef, vm: &VirtualMachine) -> PyResult<PyStrRef> {
fn __format__(
zelf: PyRef<PyStr>,
spec: PyStrRef,
vm: &VirtualMachine,
) -> PyResult<PyRef<PyStr>> {
let spec = spec.as_str();
if spec.is_empty() {
return if zelf.class().is(vm.ctx.types.str_type) {
@@ -989,7 +986,7 @@ impl PyStr {
zelf.as_object().str(vm)
};
}
let zelf = zelf.try_into_utf8(vm)?;
let s = FormatSpec::parse(spec)
.and_then(|format_spec| {
format_spec.format_string(&CharLenStr(zelf.as_str(), zelf.char_len()))
@@ -1351,8 +1348,12 @@ impl PyStr {
}
#[pymethod]
fn expandtabs(&self, args: anystr::ExpandTabsArgs) -> String {
rustpython_common::str::expandtabs(self.as_str(), args.tabsize())
fn expandtabs(&self, args: anystr::ExpandTabsArgs, vm: &VirtualMachine) -> PyResult<String> {
// TODO: support WTF-8
Ok(rustpython_common::str::expandtabs(
self.try_as_utf8(vm)?.as_str(),
args.tabsize(),
))
}
#[pymethod]
@@ -1480,20 +1481,6 @@ impl PyStr {
}
}
/// A string slice bundled with a character count supplied by whoever constructs it.
struct CharLenStr<'a>(&'a str, usize);
/// Dereferences to the wrapped string slice (field 0).
impl std::ops::Deref for CharLenStr<'_> {
type Target = str;
fn deref(&self) -> &Self::Target {
self.0
}
}
/// Answers `char_len` with the stored count (field 1) rather than inspecting the slice.
impl crate::common::format::CharLen for CharLenStr<'_> {
fn char_len(&self) -> usize {
self.1
}
}
#[pyclass]
impl PyRef<PyStr> {
#[pymethod]
@@ -1504,7 +1491,7 @@ impl PyRef<PyStr> {
}
}
impl PyStrRef {
impl PyRef<PyStr> {
pub fn is_empty(&self) -> bool {
(**self).is_empty()
}
@@ -1526,6 +1513,20 @@ impl PyStrRef {
}
}
/// A string slice bundled with a character count supplied by whoever constructs it.
struct CharLenStr<'a>(&'a str, usize);

/// Dereferences to the wrapped string slice.
impl std::ops::Deref for CharLenStr<'_> {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        let &Self(text, _) = self;
        text
    }
}

/// Answers `char_len` with the stored count rather than inspecting the slice.
impl crate::common::format::CharLen for CharLenStr<'_> {
    fn char_len(&self) -> usize {
        let &Self(_, count) = self;
        count
    }
}
impl Representable for PyStr {
#[inline]
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
@@ -1941,6 +1942,170 @@ impl AnyStrWrapper<AsciiStr> for PyStrRef {
}
}
/// Transparent wrapper around [`PyStr`] whose contents are expected to be valid UTF-8.
///
/// The invariant is asserted in debug builds by `as_str`; values are obtained either from
/// the `From` conversions in this module or by checked casts such as `PyStr::try_as_utf8`.
#[repr(transparent)]
#[derive(Debug)]
pub struct PyUtf8Str(PyStr);
/// Formats a `PyUtf8Str` by forwarding to the wrapped `PyStr`.
impl fmt::Display for PyUtf8Str {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// Delegate to the inner value's `fmt`, passing the formatter through unchanged.
self.0.fmt(f)
}
}
/// Traversal for `PyUtf8Str` is whatever the wrapped `PyStr` does.
impl MaybeTraverse for PyUtf8Str {
/// Forwards `traverse_fn` to the inner `PyStr`'s `try_traverse`.
fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>) {
self.0.try_traverse(traverse_fn);
}
}
/// Payload glue that lets `PyUtf8Str` stand in for a `PyStr` payload holding valid UTF-8.
impl PyPayload for PyUtf8Str {
    /// Uses the `str` type object from the context.
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.str_type
    }

    /// Reports the type id of `PyStr` rather than of this wrapper.
    fn payload_type_id() -> std::any::TypeId {
        std::any::TypeId::of::<PyStr>()
    }

    /// An object qualifies only if its recorded type id matches and its string is UTF-8.
    fn downcastable_from(obj: &PyObject) -> bool {
        if obj.typeid() != Self::payload_type_id() {
            return false;
        }
        // SAFETY: the type id comparison above established that the payload is a `PyStr`.
        let py_str = unsafe { obj.downcast_unchecked_ref::<PyStr>() };
        py_str.is_utf8()
    }

    /// Fails if `obj` cannot be viewed as a `PyStr`, or if that string is rejected by
    /// `ensure_valid_utf8`.
    fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> {
        obj.try_downcast_ref::<PyStr>(vm)?.ensure_valid_utf8(vm)
    }
}
/// Converts a borrowed ASCII string by copying it into an owned one first.
impl<'a> From<&'a AsciiStr> for PyUtf8Str {
    fn from(s: &'a AsciiStr) -> Self {
        Self::from(s.to_owned())
    }
}
/// Converts an owned ASCII string by way of its boxed form.
impl From<AsciiString> for PyUtf8Str {
    fn from(s: AsciiString) -> Self {
        Self::from(s.into_boxed_ascii_str())
    }
}
/// Converts a boxed ASCII string without re-validating it.
impl From<Box<AsciiStr>> for PyUtf8Str {
    fn from(s: Box<AsciiStr>) -> Self {
        let str_data = StrData::from(s);
        // SAFETY: the data came from an ASCII string, and ASCII text is always valid UTF-8.
        unsafe { Self::from_str_data_unchecked(str_data) }
    }
}
/// Converts a single ASCII character by first turning it into an `AsciiString`.
impl From<AsciiChar> for PyUtf8Str {
    fn from(ch: AsciiChar) -> Self {
        Self::from(AsciiString::from(ch))
    }
}
/// Converts a borrowed string slice by copying it into an owned `String` first.
impl<'a> From<&'a str> for PyUtf8Str {
    fn from(s: &'a str) -> Self {
        Self::from(s.to_owned())
    }
}
/// Converts an owned `String` by way of its boxed form.
impl From<String> for PyUtf8Str {
    fn from(s: String) -> Self {
        Self::from(s.into_boxed_str())
    }
}
/// Converts a single `char` without re-validating it.
impl From<char> for PyUtf8Str {
    fn from(ch: char) -> Self {
        let str_data = StrData::from(ch);
        // SAFETY: a Rust `char` is a Unicode scalar value, which always encodes as valid UTF-8.
        unsafe { Self::from_str_data_unchecked(str_data) }
    }
}
/// Converts a `Cow<str>` by taking (or cloning into) an owned `String` first.
impl<'a> From<std::borrow::Cow<'a, str>> for PyUtf8Str {
    fn from(s: std::borrow::Cow<'a, str>) -> Self {
        Self::from(s.into_owned())
    }
}
/// Converts a boxed string slice without re-validating it.
impl From<Box<str>> for PyUtf8Str {
    #[inline]
    fn from(value: Box<str>) -> Self {
        let str_data = StrData::from(value);
        // SAFETY: the data came from a `str`, which is valid UTF-8 by definition.
        unsafe { Self::from_str_data_unchecked(str_data) }
    }
}
/// Views the string as WTF-8 by asking the wrapped `PyStr` for its `as_wtf8` view.
impl AsRef<Wtf8> for PyUtf8Str {
#[inline]
fn as_ref(&self) -> &Wtf8 {
self.0.as_wtf8()
}
}
/// Views the string as `&str`.
impl AsRef<str> for PyUtf8Str {
#[inline]
fn as_ref(&self) -> &str {
// Note: this calls `as_str` on the inner `PyStr` (field 0), not `PyUtf8Str::as_str`.
self.0.as_str()
}
}
impl PyUtf8Str {
    /// Creates a new `PyUtf8Str` from `StrData` without validation.
    ///
    /// This function must only be used inside this module to build conversions.
    ///
    /// # Safety
    ///
    /// `data` must hold valid UTF-8 string data.
    unsafe fn from_str_data_unchecked(data: StrData) -> Self {
        let inner = PyStr::from(data);
        Self(inner)
    }

    /// Returns the underlying string slice.
    ///
    /// Debug builds assert that the wrapped string reports itself as UTF-8; the result of
    /// `to_str` is then unwrapped without a runtime check.
    pub fn as_str(&self) -> &str {
        let inner = &self.0;
        debug_assert!(
            inner.is_utf8(),
            "PyUtf8Str invariant violated: inner string is not valid UTF-8"
        );
        // SAFETY: the type invariant guarantees UTF-8 validity, so `to_str` is expected to
        // yield a value here and the unchecked unwrap relies on that.
        unsafe { inner.to_str().unwrap_unchecked() }
    }

    /// Forwards to the wrapped `PyStr`'s `byte_len`.
    #[inline]
    pub fn byte_len(&self) -> usize {
        self.0.byte_len()
    }

    /// Forwards to the wrapped `PyStr`'s `is_empty`.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Forwards to the wrapped `PyStr`'s `char_len`.
    #[inline]
    pub fn char_len(&self) -> usize {
        self.0.char_len()
    }
}
impl Py<PyUtf8Str> {
    /// Upcast to `PyStr`: reinterprets this reference as a `&Py<PyStr>`.
    pub fn as_pystr(&self) -> &Py<PyStr> {
        let str_ptr = (self as *const Self).cast::<Py<PyStr>>();
        // SAFETY: `PyUtf8Str` is a `#[repr(transparent)]` wrapper around `PyStr`, so the
        // pointee can be viewed through either payload type. This assumes `Py<T>` is laid
        // out identically for both — TODO confirm against the definition of `Py`.
        unsafe { &*str_ptr }
    }
}
/// Two `PyUtf8Str` values are equal when their string slices are equal.
impl PartialEq for PyUtf8Str {
fn eq(&self, other: &Self) -> bool {
// Compare by contents via `as_str`; nothing else about the wrapped value is consulted.
self.as_str() == other.as_str()
}
}
// Marker impl: equality is the `str` comparison above.
impl Eq for PyUtf8Str {}
impl AnyStrContainer<str> for String {
fn new() -> Self {
Self::new()
@@ -2302,7 +2467,8 @@ impl std::fmt::Display for PyStrInterned {
impl AsRef<str> for PyStrInterned {
#[inline(always)]
fn as_ref(&self) -> &str {
self.as_str()
self.to_str()
.expect("Interned PyStr should always be valid UTF-8")
}
}

View File

@@ -78,12 +78,12 @@ where
#[inline]
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
let class = T::class(&vm.ctx);
let result = if obj.fast_isinstance(class) {
obj.downcast()
if obj.fast_isinstance(class) {
T::try_downcast_from(&obj, vm)?;
Ok(unsafe { obj.downcast_unchecked() })
} else {
Err(obj)
};
result.map_err(|obj| vm.new_downcast_type_error(class, &obj))
Err(vm.new_downcast_type_error(class, &obj))
}
}
}

View File

@@ -448,7 +448,7 @@ impl<T: PyObjectPayload> PyInner<T> {
let member_count = typ.slots.member_count;
Box::new(Self {
ref_count: RefCount::new(),
typeid: TypeId::of::<T>(),
typeid: T::payload_type_id(),
vtable: PyObjVTable::of::<T>(),
typ: PyAtomicRef::from(typ),
dict: dict.map(InstanceDict::new),
@@ -541,6 +541,11 @@ impl PyObjectRef {
}
}
/// Converts this reference into a `PyRef<T>` after `T::try_downcast_from` accepts the object.
///
/// # Errors
///
/// Returns whatever error `T::try_downcast_from` produces when it rejects the object.
pub fn try_downcast<T: PyObjectPayload>(self, vm: &VirtualMachine) -> PyResult<PyRef<T>> {
T::try_downcast_from(&self, vm)?;
// SAFETY: relies on `try_downcast_from` having just succeeded for this object.
Ok(unsafe { self.downcast_unchecked() })
}
/// Force to downcast this reference to a subclass.
///
/// # Safety
@@ -720,10 +725,24 @@ impl PyObject {
}
}
/// Returns the `TypeId` recorded on this object (the `typeid` field of the inner value).
#[inline]
pub(crate) fn typeid(&self) -> TypeId {
self.0.typeid
}
/// Check if this object can be downcast to T.
#[inline(always)]
pub fn downcastable<T: PyObjectPayload>(&self) -> bool {
self.0.typeid == T::payload_type_id()
T::downcastable_from(self)
}
/// Attempt to view this object as a `&Py<T>`.
///
/// # Errors
///
/// Returns whatever error `T::try_downcast_from` produces when it rejects the object.
pub fn try_downcast_ref<'a, T: PyObjectPayload>(
&'a self,
vm: &VirtualMachine,
) -> PyResult<&'a Py<T>> {
T::try_downcast_from(self, vm)?;
// SAFETY: relies on `try_downcast_from` having just succeeded for this object.
Ok(unsafe { self.downcast_unchecked_ref::<T>() })
}
/// Attempt to downcast this reference to a subclass.

View File

@@ -1,6 +1,6 @@
use crate::object::{MaybeTraverse, Py, PyObjectRef, PyRef, PyResult};
use crate::{
PyRefExact,
PyObject, PyRefExact,
builtins::{PyBaseExceptionRef, PyType, PyTypeRef},
types::PyTypeFlags,
vm::{Context, VirtualMachine},
@@ -23,6 +23,31 @@ pub trait PyPayload:
/// Returns the `TypeId` used to identify this payload.
///
/// Defaults to the `TypeId` of the implementing type; an implementation may override it
/// to report a different type's id.
fn payload_type_id() -> std::any::TypeId {
std::any::TypeId::of::<Self>()
}
/// Returns whether `obj` can be downcast to this payload type.
///
/// The default implementation compares the type id recorded on `obj` with
/// [`Self::payload_type_id`]. An overriding implementation may add further checks
/// after that comparison.
///
/// NOTE(review): this is a safe function and the default body performs the type-id
/// comparison itself, so callers have no precondition to uphold.
#[inline]
fn downcastable_from(obj: &PyObject) -> bool {
obj.typeid() == Self::payload_type_id()
}
/// Checks that `obj` can be downcast to this payload type.
///
/// # Errors
///
/// Returns the exception built by `vm.new_downcast_type_error` for this payload's class
/// when [`Self::downcastable_from`] rejects `obj`.
fn try_downcast_from(obj: &PyObject, vm: &VirtualMachine) -> PyResult<()> {
    // Kept out of line and marked cold so the failure path stays off the hot path.
    #[cold]
    fn raise_downcast_type_error(
        vm: &VirtualMachine,
        class: &Py<PyType>,
        obj: &PyObject,
    ) -> PyBaseExceptionRef {
        vm.new_downcast_type_error(class, obj)
    }

    if !Self::downcastable_from(obj) {
        let expected = Self::class(&vm.ctx);
        return Err(raise_downcast_type_error(vm, expected, obj));
    }
    Ok(())
}
/// Returns the type object associated with this payload, looked up from `ctx`.
fn class(ctx: &Context) -> &'static Py<PyType>;
#[inline]