feature: PyTraverse derive macro for traverse object's childrens(like CPython's tp_traverse) (#4872)

This commit is contained in:
discord9
2023-04-23 20:43:49 +08:00
committed by GitHub
parent 2c90b128c3
commit 94bdb6b97a
59 changed files with 940 additions and 120 deletions

View File

@@ -18,6 +18,7 @@ mod pyclass;
mod pymodule;
mod pypayload;
mod pystructseq;
mod pytraverse;
use error::{extract_spans, Diagnostic};
use proc_macro2::TokenStream;
@@ -77,3 +78,7 @@ pub fn py_freeze(input: TokenStream, compiler: &dyn Compiler) -> TokenStream {
pub fn pypayload(input: DeriveInput) -> TokenStream {
result_to_tokens(pypayload::impl_pypayload(input))
}
pub fn pytraverse(item: DeriveInput) -> TokenStream {
result_to_tokens(pytraverse::impl_pytraverse(item))
}

View File

@@ -413,8 +413,59 @@ pub(crate) fn impl_pyclass(attr: AttributeArgs, item: Item) -> Result<TokenStrea
attrs,
)?;
const ALLOWED_TRAVERSE_OPTS: &[&str] = &["manual"];
// try to know if it have a `#[pyclass(trace)]` exist on this struct
// TODO(discord9): rethink on auto detect `#[Derive(PyTrace)]`
// 1. no `traverse` at all: generate a dummy try_traverse
// 2. `traverse = "manual"`: generate a try_traverse, but not #[derive(Traverse)]
// 3. `traverse`: generate a try_traverse, and #[derive(Traverse)]
let (maybe_trace_code, derive_trace) = {
if class_meta.inner()._has_key("traverse")? {
let maybe_trace_code = quote! {
impl ::rustpython_vm::object::MaybeTraverse for #ident {
const IS_TRACE: bool = true;
fn try_traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) {
::rustpython_vm::object::Traverse::traverse(self, tracer_fn);
}
}
};
// if the key `traverse` exist but not as key-value, _optional_str return Err(...)
// so we need to check if it is Ok(Some(...))
let value = class_meta.inner()._optional_str("traverse");
let derive_trace = if let Ok(Some(s)) = value {
if !ALLOWED_TRAVERSE_OPTS.contains(&s.as_str()) {
bail_span!(
item,
"traverse attribute only accept {ALLOWED_TRAVERSE_OPTS:?} as value or no value at all",
);
}
assert_eq!(s, "manual");
quote! {}
} else {
quote! {#[derive(Traverse)]}
};
(maybe_trace_code, derive_trace)
} else {
(
// a dummy impl, which do nothing
// #attrs
quote! {
impl ::rustpython_vm::object::MaybeTraverse for #ident {
fn try_traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) {
// do nothing
}
}
},
quote! {},
)
}
};
let ret = quote! {
#derive_trace
#item
#maybe_trace_code
#class_def
};
Ok(ret)

View File

@@ -0,0 +1,138 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, DeriveInput, Field, Meta, MetaList, NestedMeta, Result};
struct TraverseAttr {
/// set to `true` if the attribute is `#[pytraverse(skip)]`
skip: bool,
}
const ATTR_TRAVERSE: &str = "pytraverse";
/// get the `#[pytraverse(..)]` attribute from the struct
fn valid_get_traverse_attr_from_meta_list(list: &MetaList) -> Result<TraverseAttr> {
let find_skip_and_only_skip = || {
let len = list.nested.len();
if len != 1 {
return None;
}
let mut iter = list.nested.iter();
// we have checked the length, so unwrap is safe
let first_arg = iter.next().unwrap();
let skip = match first_arg {
NestedMeta::Meta(Meta::Path(path)) => match path.is_ident("skip") {
true => true,
false => return None,
},
_ => return None,
};
Some(skip)
};
let skip = find_skip_and_only_skip().ok_or_else(|| {
err_span!(
list,
"only support attr is #[pytraverse(skip)], got arguments: {:?}",
list.nested
)
})?;
Ok(TraverseAttr { skip })
}
/// only accept `#[pytraverse(skip)]` for now
fn pytraverse_arg(attr: &Attribute) -> Option<Result<TraverseAttr>> {
if !attr.path.is_ident(ATTR_TRAVERSE) {
return None;
}
let ret = || {
let parsed = attr.parse_meta()?;
if let Meta::List(list) = parsed {
valid_get_traverse_attr_from_meta_list(&list)
} else {
bail_span!(attr, "pytraverse must be a list, like #[pytraverse(skip)]")
}
};
Some(ret())
}
fn field_to_traverse_code(field: &Field) -> Result<TokenStream> {
let pytraverse_attrs = field
.attrs
.iter()
.filter_map(pytraverse_arg)
.collect::<std::result::Result<Vec<_>, _>>()?;
let do_trace = if pytraverse_attrs.len() > 1 {
bail_span!(
field,
"found multiple #[pytraverse] attributes on the same field, expect at most one"
)
} else if pytraverse_attrs.is_empty() {
// default to always traverse every field
true
} else {
!pytraverse_attrs[0].skip
};
let name = field.ident.as_ref().ok_or_else(|| {
syn::Error::new_spanned(
field.clone(),
"Field should have a name in non-tuple struct",
)
})?;
if do_trace {
Ok(quote!(
::rustpython_vm::object::Traverse::traverse(&self.#name, tracer_fn);
))
} else {
Ok(quote!())
}
}
/// not trace corresponding field
fn gen_trace_code(item: &mut DeriveInput) -> Result<TokenStream> {
match &mut item.data {
syn::Data::Struct(s) => {
let fields = &mut s.fields;
match fields {
syn::Fields::Named(ref mut fields) => {
let res: Vec<TokenStream> = fields
.named
.iter_mut()
.map(|f| -> Result<TokenStream> { field_to_traverse_code(f) })
.collect::<Result<_>>()?;
let res = res.into_iter().collect::<TokenStream>();
Ok(res)
}
syn::Fields::Unnamed(fields) => {
let res: TokenStream = (0..fields.unnamed.len())
.map(|i| {
let i = syn::Index::from(i);
quote!(
::rustpython_vm::object::Traverse::traverse(&self.#i, tracer_fn);
)
})
.collect();
Ok(res)
}
_ => Err(syn::Error::new_spanned(
fields,
"Only named and unnamed fields are supported",
)),
}
}
_ => Err(syn::Error::new_spanned(item, "Only structs are supported")),
}
}
pub(crate) fn impl_pytraverse(mut item: DeriveInput) -> Result<TokenStream> {
let trace_code = gen_trace_code(&mut item)?;
let ty = &item.ident;
let ret = quote! {
unsafe impl ::rustpython_vm::object::Traverse for #ty {
fn traverse(&self, tracer_fn: &mut ::rustpython_vm::object::TraverseFn) {
#trace_code
}
}
};
Ok(ret)
}

View File

@@ -178,6 +178,10 @@ impl ItemMetaInner {
Ok(value)
}
pub fn _has_key(&self, key: &str) -> Result<bool> {
Ok(matches!(self.meta_map.get(key), Some((_, _))))
}
pub fn _bool(&self, key: &str) -> Result<bool> {
let value = if let Some((_, meta)) = self.meta_map.get(key) {
match meta {
@@ -263,8 +267,14 @@ impl ItemMeta for AttrItemMeta {
pub(crate) struct ClassItemMeta(ItemMetaInner);
impl ItemMeta for ClassItemMeta {
const ALLOWED_NAMES: &'static [&'static str] =
&["module", "name", "base", "metaclass", "unhashable"];
const ALLOWED_NAMES: &'static [&'static str] = &[
"module",
"name",
"base",
"metaclass",
"unhashable",
"traverse",
];
fn from_inner(inner: ItemMetaInner) -> Self {
Self(inner)

View File

@@ -91,3 +91,27 @@ pub fn pypayload(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input);
derive_impl::pypayload(input).into()
}
/// use on struct with named fields like `struct A{x:PyRef<B>, y:PyRef<C>}` to impl `Traverse` for datatype.
///
/// use `#[pytraverse(skip)]` on fields you wish not to trace
///
/// add `trace` attr to `#[pyclass]` to make it impl `MaybeTraverse` that will call `Traverse`'s `traverse` method so make it
/// traceable(Even from type-erased PyObject)(i.e. write `#[pyclass(trace)]`).
/// # Example
/// ```rust, ignore
/// #[pyclass(module = false, traverse)]
/// #[derive(Default, Traverse)]
/// pub struct PyList {
/// elements: PyRwLock<Vec<PyObjectRef>>,
/// #[pytraverse(skip)]
/// len: AtomicCell<usize>,
/// }
/// ```
/// This create both `MaybeTraverse` that call `Traverse`'s `traverse` method and `Traverse` that impl `Traverse`
/// for `PyList` which call elements' `traverse` method and ignore `len` field.
#[proc_macro_derive(Traverse, attributes(pytraverse))]
pub fn pytraverse(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let item = parse_macro_input!(item);
derive_impl::pytraverse(item).into()
}

View File

@@ -1399,7 +1399,7 @@ mod array {
}
#[pyattr]
#[pyclass(name = "arrayiterator")]
#[pyclass(name = "arrayiterator", traverse)]
#[derive(Debug, PyPayload)]
pub struct PyArrayIter {
internal: PyMutex<PositionIterInternal<PyArrayRef>>,
@@ -1434,12 +1434,13 @@ mod array {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct ReconstructorArgs {
#[pyarg(positional)]
arraytype: PyTypeRef,
#[pyarg(positional)]
typecode: PyStrRef,
#[pytraverse(skip)]
#[pyarg(positional)]
mformat_code: MachineFormatCode,
#[pyarg(positional)]

View File

@@ -8,7 +8,7 @@ mod _bisect {
PyObjectRef, PyResult, VirtualMachine,
};
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct BisectArgs {
a: PyObjectRef,
x: PyObjectRef,

View File

@@ -80,9 +80,10 @@ mod _contextvars {
}
#[pyattr]
#[pyclass(name)]
#[pyclass(name, traverse)]
#[derive(Debug, PyPayload)]
struct ContextVar {
#[pytraverse(skip)]
#[allow(dead_code)] // TODO: RUSTPYTHON
name: String,
#[allow(dead_code)] // TODO: RUSTPYTHON
@@ -161,7 +162,7 @@ mod _contextvars {
#[derive(Debug, PyPayload)]
struct ContextToken {}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct ContextTokenOptions {
#[pyarg(positional)]
#[allow(dead_code)] // TODO: RUSTPYTHON

View File

@@ -152,10 +152,11 @@ mod _csv {
reader: csv_core::Reader,
}
#[pyclass(no_attr, module = "_csv", name = "reader")]
#[pyclass(no_attr, module = "_csv", name = "reader", traverse)]
#[derive(PyPayload)]
pub(super) struct Reader {
iter: PyIter,
#[pytraverse(skip)]
state: PyMutex<ReadState>,
}
@@ -242,10 +243,11 @@ mod _csv {
writer: csv_core::Writer,
}
#[pyclass(no_attr, module = "_csv", name = "writer")]
#[pyclass(no_attr, module = "_csv", name = "writer", traverse)]
#[derive(PyPayload)]
pub(super) struct Writer {
write: PyObjectRef,
#[pytraverse(skip)]
state: PyMutex<WriteState>,
}

View File

@@ -13,11 +13,14 @@ mod grp {
use std::ptr::NonNull;
#[pyattr]
#[pyclass(module = "grp", name = "struct_group")]
#[pyclass(module = "grp", name = "struct_group", traverse)]
#[derive(PyStructSequence)]
struct Group {
#[pytraverse(skip)]
gr_name: String,
#[pytraverse(skip)]
gr_passwd: String,
#[pytraverse(skip)]
gr_gid: u32,
gr_mem: PyListRef,
}

View File

@@ -19,22 +19,24 @@ mod hashlib {
use sha2::{Sha224, Sha256, Sha384, Sha512};
use sha3::{Sha3_224, Sha3_256, Sha3_384, Sha3_512, Shake128, Shake256};
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
#[allow(unused)]
struct NewHashArgs {
#[pyarg(positional)]
name: PyStrRef,
#[pyarg(any, optional)]
data: OptionalArg<ArgBytesLike>,
#[pytraverse(skip)]
#[pyarg(named, default = "true")]
usedforsecurity: bool,
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
#[allow(unused)]
struct BlakeHashArgs {
#[pyarg(positional, optional)]
data: OptionalArg<ArgBytesLike>,
#[pytraverse(skip)]
#[pyarg(named, default = "true")]
usedforsecurity: bool,
}
@@ -48,11 +50,12 @@ mod hashlib {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
#[allow(unused)]
struct HashArgs {
#[pyarg(any, optional)]
string: OptionalArg<ArgBytesLike>,
#[pytraverse(skip)]
#[pyarg(named, default = "true")]
usedforsecurity: bool,
}

View File

@@ -16,9 +16,10 @@ mod _json {
use std::str::FromStr;
#[pyattr(name = "make_scanner")]
#[pyclass(name = "Scanner")]
#[pyclass(name = "Scanner", traverse)]
#[derive(Debug, PyPayload)]
struct JsonScanner {
#[pytraverse(skip)]
strict: bool,
object_hook: Option<PyObjectRef>,
object_pairs_hook: Option<PyObjectRef>,

View File

@@ -885,7 +885,7 @@ mod math {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct ProdArgs {
#[pyarg(positional)]
iterable: ArgIterable<PyObjectRef>,

View File

@@ -244,8 +244,9 @@ mod mmap {
}
#[cfg(not(target_os = "redox"))]
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
pub struct AdviseOptions {
#[pytraverse(skip)]
#[pyarg(positional)]
option: libc::c_int,
#[pyarg(positional, default)]

View File

@@ -43,7 +43,7 @@ mod _pyexpat {
type MutableObject = PyRwLock<PyObjectRef>;
#[pyattr]
#[pyclass(name = "xmlparser", module = false)]
#[pyclass(name = "xmlparser", module = false, traverse)]
#[derive(Debug, PyPayload)]
pub struct PyExpatLikeXmlParser {
start_element: MutableObject,
@@ -156,7 +156,7 @@ mod _pyexpat {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
#[allow(dead_code)]
struct ParserCreateArgs {
#[pyarg(any, optional)]

View File

@@ -20,6 +20,7 @@ pub(crate) mod _struct {
};
use crossbeam_utils::atomic::AtomicCell;
#[derive(Traverse)]
struct IntoStructFormatBytes(PyStrRef);
impl TryFromObject for IntoStructFormatBytes {
@@ -133,9 +134,10 @@ pub(crate) mod _struct {
buffer.with_ref(|buf| format_spec.unpack(buf, vm))
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct UpdateFromArgs {
buffer: ArgBytesLike,
#[pytraverse(skip)]
#[pyarg(any, default = "0")]
offset: isize,
}
@@ -154,11 +156,13 @@ pub(crate) mod _struct {
}
#[pyattr]
#[pyclass(name = "unpack_iterator")]
#[pyclass(name = "unpack_iterator", traverse)]
#[derive(Debug, PyPayload)]
struct UnpackIterator {
#[pytraverse(skip)]
format_spec: FormatSpec,
buffer: ArgBytesLike,
#[pytraverse(skip)]
offset: AtomicCell<usize>,
}
@@ -231,9 +235,10 @@ pub(crate) mod _struct {
}
#[pyattr]
#[pyclass(name = "Struct")]
#[pyclass(name = "Struct", traverse)]
#[derive(Debug, PyPayload)]
struct PyStruct {
#[pytraverse(skip)]
spec: FormatSpec,
format: PyStrRef,
}

View File

@@ -76,10 +76,11 @@ mod re {
/// Inner data for a match object.
#[pyattr]
#[pyclass(module = "re", name = "Match")]
#[derive(PyPayload)]
#[pyclass(module = "re", name = "Match", traverse)]
#[derive(PyPayload, Traverse)]
struct PyMatch {
haystack: PyStrRef,
#[pytraverse(skip)]
captures: Vec<Option<Range<usize>>>,
}

View File

@@ -67,8 +67,10 @@ mod platform {
pub use platform::timeval;
use platform::RawFd;
#[derive(Traverse)]
struct Selectable {
obj: PyObjectRef,
#[pytraverse(skip)]
fno: RawFd,
}

View File

@@ -69,6 +69,7 @@ mod _sqlite {
AsObject, Py, PyAtomicRef, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
TryFromBorrowedObject, VirtualMachine,
__exports::paste,
object::{Traverse, TraverseFn},
};
use std::{
ffi::{c_int, c_longlong, c_uint, c_void, CStr},
@@ -311,6 +312,13 @@ mod _sqlite {
uri: bool,
}
unsafe impl Traverse for ConnectArgs {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.isolation_level.traverse(tracer_fn);
self.factory.traverse(tracer_fn);
}
}
#[derive(FromArgs)]
struct BackupArgs {
#[pyarg(any)]
@@ -325,36 +333,48 @@ mod _sqlite {
sleep: f64,
}
#[derive(FromArgs)]
unsafe impl Traverse for BackupArgs {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.progress.traverse(tracer_fn);
self.name.traverse(tracer_fn);
}
}
#[derive(FromArgs, Traverse)]
struct CreateFunctionArgs {
#[pyarg(any)]
name: PyStrRef,
#[pytraverse(skip)]
#[pyarg(any)]
narg: c_int,
#[pyarg(any)]
func: PyObjectRef,
#[pytraverse(skip)]
#[pyarg(named, default)]
deterministic: bool,
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct CreateAggregateArgs {
#[pyarg(any)]
name: PyStrRef,
#[pytraverse(skip)]
#[pyarg(positional)]
narg: c_int,
#[pyarg(positional)]
aggregate_class: PyObjectRef,
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct BlobOpenArgs {
#[pyarg(positional)]
table: PyStrRef,
#[pyarg(positional)]
column: PyStrRef,
#[pytraverse(skip)]
#[pyarg(positional)]
row: i64,
#[pytraverse(skip)]
#[pyarg(named, default)]
readonly: bool,
#[pyarg(named, default = "vm.ctx.new_str(stringify!(main))")]
@@ -1340,20 +1360,24 @@ mod _sqlite {
}
#[pyattr]
#[pyclass(name)]
#[pyclass(name, traverse)]
#[derive(Debug, PyPayload)]
struct Cursor {
connection: PyRef<Connection>,
#[pytraverse(skip)]
arraysize: PyAtomic<c_int>,
#[pytraverse(skip)]
row_factory: PyAtomicRef<Option<PyObject>>,
inner: PyMutex<Option<CursorInner>>,
}
#[derive(Debug)]
#[derive(Debug, Traverse)]
struct CursorInner {
description: Option<PyTupleRef>,
row_cast_map: Vec<Option<PyObjectRef>>,
#[pytraverse(skip)]
lastrowid: i64,
#[pytraverse(skip)]
rowcount: i64,
statement: Option<PyRef<Statement>>,
}
@@ -1793,7 +1817,7 @@ mod _sqlite {
}
#[pyattr]
#[pyclass(name)]
#[pyclass(name, traverse)]
#[derive(Debug, PyPayload)]
struct Row {
data: PyTupleRef,
@@ -1922,10 +1946,11 @@ mod _sqlite {
}
#[pyattr]
#[pyclass(name)]
#[pyclass(name, traverse)]
#[derive(Debug, PyPayload)]
struct Blob {
connection: PyRef<Connection>,
#[pytraverse(skip)]
inner: PyMutex<Option<BlobInner>>,
}

View File

@@ -797,9 +797,10 @@ mod _ssl {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct WrapSocketArgs {
sock: PyRef<PySocket>,
#[pytraverse(skip)]
server_side: bool,
#[pyarg(any, default)]
server_hostname: Option<PyStrRef>,
@@ -819,9 +820,11 @@ mod _ssl {
cadata: Option<Either<PyStrRef, ArgBytesLike>>,
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct LoadCertChainArgs {
#[pytraverse(skip)]
certfile: FsPath,
#[pytraverse(skip)]
#[pyarg(any, optional)]
keyfile: Option<FsPath>,
#[pyarg(any, optional)]
@@ -902,11 +905,13 @@ mod _ssl {
}
#[pyattr]
#[pyclass(module = "ssl", name = "_SSLSocket")]
#[pyclass(module = "ssl", name = "_SSLSocket", traverse)]
#[derive(PyPayload)]
struct PySslSocket {
ctx: PyRef<PySslContext>,
#[pytraverse(skip)]
stream: PyRwLock<ssl::SslStream<SocketStream>>,
#[pytraverse(skip)]
socket_type: SslServerOrClient,
server_hostname: Option<PyStrRef>,
owner: PyRwLock<Option<PyRef<PyWeak>>>,

View File

@@ -95,7 +95,7 @@ mod syslog {
Ok(())
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct SysLogArgs {
#[pyarg(positional)]
priority: PyObjectRef,

View File

@@ -32,7 +32,7 @@ use std::fmt;
pub type DictContentType = dictdatatype::Dict;
#[pyclass(module = false, name = "dict", unhashable = true)]
#[pyclass(module = false, name = "dict", unhashable = true, traverse)]
#[derive(Default)]
pub struct PyDict {
entries: DictContentType,

View File

@@ -13,9 +13,10 @@ use crate::{
use num_bigint::BigInt;
use num_traits::Zero;
#[pyclass(module = false, name = "enumerate")]
#[pyclass(module = false, name = "enumerate", traverse)]
#[derive(Debug)]
pub struct PyEnumerate {
#[pytraverse(skip)]
counter: PyRwLock<BigInt>,
iterator: PyIter,
}
@@ -84,7 +85,7 @@ impl IterNext for PyEnumerate {
}
}
#[pyclass(module = false, name = "reversed")]
#[pyclass(module = false, name = "reversed", traverse)]
#[derive(Debug)]
pub struct PyReverseSequenceIterator {
internal: PyMutex<PositionIterInternal<PyObjectRef>>,

View File

@@ -6,7 +6,7 @@ use crate::{
Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
};
#[pyclass(module = false, name = "filter")]
#[pyclass(module = false, name = "filter", traverse)]
#[derive(Debug)]
pub struct PyFilter {
predicate: PyObjectRef,

View File

@@ -10,6 +10,7 @@ use crate::common::lock::OnceCell;
use crate::common::lock::PyMutex;
use crate::convert::ToPyObject;
use crate::function::ArgMapping;
use crate::object::{Traverse, TraverseFn};
use crate::{
bytecode,
class::PyClassImpl,
@@ -25,7 +26,7 @@ use itertools::Itertools;
#[cfg(feature = "jit")]
use rustpython_jit::CompiledCode;
#[pyclass(module = false, name = "function")]
#[pyclass(module = false, name = "function", traverse = "manual")]
#[derive(Debug)]
pub struct PyFunction {
code: PyRef<PyCode>,
@@ -38,6 +39,14 @@ pub struct PyFunction {
jitted_code: OnceCell<CompiledCode>,
}
unsafe impl Traverse for PyFunction {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.globals.traverse(tracer_fn);
self.closure.traverse(tracer_fn);
self.defaults_and_kwdefaults.traverse(tracer_fn);
}
}
impl PyFunction {
pub(crate) fn new(
code: PyRef<PyCode>,
@@ -468,7 +477,7 @@ impl Representable for PyFunction {
}
}
#[pyclass(module = false, name = "method")]
#[pyclass(module = false, name = "method", traverse)]
#[derive(Debug)]
pub struct PyBoundMethod {
object: PyObjectRef,
@@ -513,7 +522,7 @@ impl GetAttr for PyBoundMethod {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
pub struct PyBoundMethodNewArgs {
#[pyarg(positional)]
function: PyObjectRef,
@@ -633,7 +642,7 @@ impl Representable for PyBoundMethod {
}
}
#[pyclass(module = false, name = "cell")]
#[pyclass(module = false, name = "cell", traverse)]
#[derive(Debug, Default)]
pub(crate) struct PyCell {
contents: PyMutex<Option<PyObjectRef>>,

View File

@@ -6,6 +6,7 @@ use super::{PyInt, PyTupleRef, PyType};
use crate::{
class::PyClassImpl,
function::ArgCallable,
object::{Traverse, TraverseFn},
protocol::{PyIterReturn, PySequence, PySequenceMethods},
types::{IterNext, IterNextIterable},
Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
@@ -24,12 +25,27 @@ pub enum IterStatus<T> {
Exhausted,
}
unsafe impl<T: Traverse> Traverse for IterStatus<T> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
match self {
IterStatus::Active(ref r) => r.traverse(tracer_fn),
IterStatus::Exhausted => (),
}
}
}
#[derive(Debug)]
pub struct PositionIterInternal<T> {
pub status: IterStatus<T>,
pub position: usize,
}
unsafe impl<T: Traverse> Traverse for PositionIterInternal<T> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.status.traverse(tracer_fn)
}
}
impl<T> PositionIterInternal<T> {
pub fn new(obj: T, position: usize) -> Self {
Self {
@@ -158,10 +174,11 @@ pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObject {
INSTANCE.get_or_init(|| vm.builtins.get_attr("reversed", vm).unwrap())
}
#[pyclass(module = false, name = "iterator")]
#[pyclass(module = false, name = "iterator", traverse)]
#[derive(Debug)]
pub struct PySequenceIterator {
// cached sequence methods
#[pytraverse(skip)]
seq_methods: &'static PySequenceMethods,
internal: PyMutex<PositionIterInternal<PyObjectRef>>,
}
@@ -222,7 +239,7 @@ impl IterNext for PySequenceIterator {
}
}
#[pyclass(module = false, name = "callable_iterator")]
#[pyclass(module = false, name = "callable_iterator", traverse)]
#[derive(Debug)]
pub struct PyCallableIterator {
sentinel: PyObjectRef,

View File

@@ -22,7 +22,7 @@ use crate::{
};
use std::{fmt, ops::DerefMut};
#[pyclass(module = false, name = "list", unhashable = true)]
#[pyclass(module = false, name = "list", unhashable = true, traverse)]
#[derive(Default)]
pub struct PyList {
elements: PyRwLock<Vec<PyObjectRef>>,
@@ -86,10 +86,11 @@ impl PyList {
}
}
#[derive(FromArgs, Default)]
#[derive(FromArgs, Default, Traverse)]
pub(crate) struct SortOptions {
#[pyarg(named, default)]
key: Option<PyObjectRef>,
#[pytraverse(skip)]
#[pyarg(named, default = "false")]
reverse: bool,
}
@@ -530,7 +531,7 @@ fn do_sort(
Ok(())
}
#[pyclass(module = false, name = "list_iterator")]
#[pyclass(module = false, name = "list_iterator", traverse)]
#[derive(Debug)]
pub struct PyListIterator {
internal: PyMutex<PositionIterInternal<PyListRef>>,
@@ -575,7 +576,7 @@ impl IterNext for PyListIterator {
}
}
#[pyclass(module = false, name = "list_reverseiterator")]
#[pyclass(module = false, name = "list_reverseiterator", traverse)]
#[derive(Debug)]
pub struct PyListReverseIterator {
internal: PyMutex<PositionIterInternal<PyListRef>>,

View File

@@ -8,7 +8,7 @@ use crate::{
Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
};
#[pyclass(module = false, name = "map")]
#[pyclass(module = false, name = "map", traverse)]
#[derive(Debug)]
pub struct PyMap {
mapper: PyObjectRef,

View File

@@ -4,6 +4,7 @@ use crate::{
class::PyClassImpl,
convert::ToPyObject,
function::{ArgMapping, OptionalArg, PyComparisonValue},
object::{Traverse, TraverseFn},
protocol::{PyMapping, PyMappingMethods, PyNumberMethods, PySequenceMethods},
types::{
AsMapping, AsNumber, AsSequence, Comparable, Constructor, Iterable, PyComparisonOp,
@@ -13,7 +14,7 @@ use crate::{
};
use once_cell::sync::Lazy;
#[pyclass(module = false, name = "mappingproxy")]
#[pyclass(module = false, name = "mappingproxy", traverse)]
#[derive(Debug)]
pub struct PyMappingProxy {
mapping: MappingProxyInner,
@@ -25,6 +26,15 @@ enum MappingProxyInner {
Mapping(ArgMapping),
}
unsafe impl Traverse for MappingProxyInner {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
match self {
MappingProxyInner::Class(ref r) => r.traverse(tracer_fn),
MappingProxyInner::Mapping(ref arg) => arg.traverse(tracer_fn),
}
}
}
impl PyPayload for PyMappingProxy {
fn class(ctx: &Context) -> &'static Py<PyType> {
ctx.types.mappingproxy_type

View File

@@ -33,7 +33,7 @@ use once_cell::sync::Lazy;
use rustpython_common::lock::PyMutex;
use std::{cmp::Ordering, fmt::Debug, mem::ManuallyDrop, ops::Range};
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
pub struct PyMemoryViewNewArgs {
object: PyObjectRef,
}
@@ -896,7 +896,7 @@ impl Py<PyMemoryView> {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct CastArgs {
#[pyarg(any)]
format: PyStrRef,
@@ -1126,7 +1126,7 @@ impl Iterable for PyMemoryView {
}
#[pyclass(module = false, name = "memory_iterator")]
#[derive(Debug)]
#[derive(Debug, Traverse)]
pub struct PyMemoryViewIterator {
internal: PyMutex<PositionIterInternal<PyRef<PyMemoryView>>>,
}

View File

@@ -11,7 +11,7 @@ use crate::{
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
};
#[pyclass(module = false, name = "property")]
#[pyclass(module = false, name = "property", traverse)]
#[derive(Debug)]
pub struct PyProperty {
getter: PyRwLock<Option<PyObjectRef>>,

View File

@@ -28,7 +28,7 @@ use std::{fmt, ops::Deref};
pub type SetContentType = dictdatatype::Dict<()>;
#[pyclass(module = false, name = "set", unhashable = true)]
#[pyclass(module = false, name = "set", unhashable = true, traverse)]
#[derive(Default)]
pub struct PySet {
pub(super) inner: PySetInner,
@@ -151,6 +151,13 @@ pub(super) struct PySetInner {
content: PyRc<SetContentType>,
}
unsafe impl crate::object::Traverse for PySetInner {
fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn) {
// FIXME(discord9): Rc means shared ref, so should it be traced?
self.content.traverse(tracer_fn)
}
}
impl PySetInner {
pub(super) fn from_iter<T>(iter: T, vm: &VirtualMachine) -> PyResult<Self>
where

View File

@@ -12,7 +12,7 @@ use crate::{
use num_bigint::{BigInt, ToBigInt};
use num_traits::{One, Signed, Zero};
#[pyclass(module = false, name = "slice", unhashable = true)]
#[pyclass(module = false, name = "slice", unhashable = true, traverse)]
#[derive(Debug)]
pub struct PySlice {
pub start: Option<PyObjectRef>,

View File

@@ -8,7 +8,7 @@ use crate::{
Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
};
#[pyclass(module = false, name = "staticmethod")]
#[pyclass(module = false, name = "staticmethod", traverse)]
#[derive(Debug)]
pub struct PyStaticMethod {
pub callable: PyMutex<PyObjectRef>,

View File

@@ -15,6 +15,7 @@ use crate::{
format::{format, format_map},
function::{ArgIterable, ArgSize, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue},
intern::PyInterned,
object::{Traverse, TraverseFn},
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
sequence::SequenceExt,
sliceable::{SequenceIndex, SliceableSequenceOp},
@@ -186,12 +187,19 @@ impl<'a> AsPyStr<'a> for &'a PyStrInterned {
}
}
#[pyclass(module = false, name = "str_iterator")]
#[pyclass(module = false, name = "str_iterator", traverse = "manual")]
#[derive(Debug)]
pub struct PyStrIterator {
internal: PyMutex<(PositionIterInternal<PyStrRef>, usize)>,
}
unsafe impl Traverse for PyStrIterator {
fn traverse(&self, tracer: &mut TraverseFn) {
// No need to worry about deadlock, for inner is a PyStr and can't make ref cycle
self.internal.lock().0.traverse(tracer);
}
}
impl PyPayload for PyStrIterator {
fn class(ctx: &Context) -> &'static Py<PyType> {
ctx.types.str_iterator_type
@@ -251,7 +259,7 @@ impl IterNext for PyStrIterator {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
pub struct StrArgs {
#[pyarg(any, optional)]
object: OptionalArg<PyObjectRef>,
@@ -1392,7 +1400,7 @@ impl AsSequence for PyStr {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
struct EncodeArgs {
#[pyarg(any, default)]
encoding: Option<PyStrRef>,
@@ -1456,7 +1464,7 @@ impl ToPyObject for AsciiString {
type SplitArgs = anystr::SplitArgs<PyStrRef>;
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
pub struct FindArgs {
#[pyarg(positional)]
sub: PyStrRef,

View File

@@ -11,7 +11,7 @@ use crate::{
AsObject, Context, Py, PyObjectRef, PyPayload, PyResult, VirtualMachine,
};
#[pyclass(module = false, name = "super")]
#[pyclass(module = false, name = "super", traverse)]
#[derive(Debug)]
pub struct PySuper {
typ: PyTypeRef,
@@ -24,7 +24,7 @@ impl PyPayload for PySuper {
}
}
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
pub struct PySuperNewArgs {
#[pyarg(positional, optional)]
py_type: OptionalArg<PyTypeRef>,

View File

@@ -3,12 +3,14 @@ use rustpython_common::lock::PyMutex;
use super::PyType;
use crate::{class::PyClassImpl, frame::FrameRef, Context, Py, PyPayload, PyRef};
#[pyclass(module = false, name = "traceback")]
#[pyclass(module = false, name = "traceback", traverse)]
#[derive(Debug)]
pub struct PyTraceback {
pub next: PyMutex<Option<PyTracebackRef>>,
pub frame: FrameRef,
#[pytraverse(skip)]
pub lasti: u32,
#[pytraverse(skip)]
pub lineno: usize,
}

View File

@@ -1,5 +1,6 @@
use super::{PositionIterInternal, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
use crate::common::{hash::PyHash, lock::PyMutex};
use crate::object::{Traverse, TraverseFn};
use crate::{
atomic_func,
class::PyClassImpl,
@@ -21,7 +22,7 @@ use crate::{
use once_cell::sync::Lazy;
use std::{fmt, marker::PhantomData};
#[pyclass(module = false, name = "tuple")]
#[pyclass(module = false, name = "tuple", traverse)]
pub struct PyTuple {
elements: Box<[PyObjectRef]>,
}
@@ -421,7 +422,7 @@ impl Representable for PyTuple {
}
}
#[pyclass(module = false, name = "tuple_iterator")]
#[pyclass(module = false, name = "tuple_iterator", traverse)]
#[derive(Debug)]
pub(crate) struct PyTupleIterator {
internal: PyMutex<PositionIterInternal<PyTupleRef>>,
@@ -479,6 +480,15 @@ pub struct PyTupleTyped<T: TransmuteFromObject> {
_marker: PhantomData<Vec<T>>,
}
unsafe impl<T> Traverse for PyTupleTyped<T>
where
T: TransmuteFromObject + Traverse,
{
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.tuple.traverse(tracer_fn);
}
}
impl<T: TransmuteFromObject> TryFromObject for PyTupleTyped<T> {
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
let tuple = PyTupleRef::try_from_object(vm, obj)?;

View File

@@ -20,6 +20,7 @@ use crate::{
convert::ToPyResult,
function::{FuncArgs, KwArgs, OptionalArg, PySetterValue},
identifier,
object::{Traverse, TraverseFn},
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
types::{AsNumber, Callable, GetAttr, PyTypeFlags, PyTypeSlots, Representable, SetAttr},
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
@@ -29,7 +30,7 @@ use indexmap::{map::Entry, IndexMap};
use itertools::Itertools;
use std::{borrow::Borrow, collections::HashSet, fmt, ops::Deref, pin::Pin, ptr::NonNull};
#[pyclass(module = false, name = "type")]
#[pyclass(module = false, name = "type", traverse = "manual")]
pub struct PyType {
pub base: Option<PyTypeRef>,
pub bases: Vec<PyTypeRef>,
@@ -40,6 +41,20 @@ pub struct PyType {
pub heaptype_ext: Option<Pin<Box<HeapTypeExt>>>,
}
unsafe impl crate::object::Traverse for PyType {
fn traverse(&self, tracer_fn: &mut crate::object::TraverseFn) {
self.base.traverse(tracer_fn);
self.bases.traverse(tracer_fn);
self.mro.traverse(tracer_fn);
self.subclasses.traverse(tracer_fn);
self.attributes
.read_recursive()
.iter()
.map(|(_, v)| v.traverse(tracer_fn))
.count();
}
}
pub struct HeapTypeExt {
pub name: PyRwLock<PyStrRef>,
pub slots: Option<PyTupleTyped<PyStrRef>>,
@@ -100,6 +115,12 @@ cfg_if::cfg_if! {
/// faster and only supports strings as keys.
pub type PyAttributes = IndexMap<&'static PyStrInterned, PyObjectRef, ahash::RandomState>;
unsafe impl Traverse for PyAttributes {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.values().for_each(|v| v.traverse(tracer_fn));
}
}
impl fmt::Display for PyType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.name(), f)

View File

@@ -15,7 +15,7 @@ use std::fmt;
const CLS_ATTRS: &[&str] = &["__module__"];
#[pyclass(module = "types", name = "UnionType")]
#[pyclass(module = "types", name = "UnionType", traverse)]
pub struct PyUnion {
args: PyTupleRef,
parameters: PyTupleRef,

View File

@@ -13,7 +13,7 @@ use crate::{
};
use once_cell::sync::Lazy;
#[pyclass(module = false, name = "weakproxy", unhashable = true)]
#[pyclass(module = false, name = "weakproxy", unhashable = true, traverse)]
#[derive(Debug)]
pub struct PyWeakProxy {
weak: PyRef<PyWeak>,

View File

@@ -12,7 +12,7 @@ use crate::{
pub use crate::object::PyWeak;
#[derive(FromArgs)]
#[derive(FromArgs, Traverse)]
pub struct WeakNewArgs {
#[pyarg(positional)]
referent: PyObjectRef,

View File

@@ -9,10 +9,11 @@ use crate::{
};
use rustpython_common::atomic::{self, PyAtomic, Radium};
#[pyclass(module = false, name = "zip")]
#[pyclass(module = false, name = "zip", traverse)]
#[derive(Debug)]
pub struct PyZip {
iterators: Vec<PyIter>,
#[pytraverse(skip)]
strict: PyAtomic<bool>,
}

View File

@@ -3,15 +3,18 @@
//! And: https://www.youtube.com/watch?v=p33CVV29OG8
//! And: http://code.activestate.com/recipes/578375/
use crate::common::{
hash,
lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard},
};
use crate::{
builtins::{PyInt, PyStr, PyStrInterned, PyStrRef},
convert::ToPyObject,
AsObject, Py, PyExact, PyObject, PyObjectRef, PyRefExact, PyResult, VirtualMachine,
};
use crate::{
common::{
hash,
lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard},
},
object::{Traverse, TraverseFn},
};
use num_traits::ToPrimitive;
use std::{fmt, mem::size_of, ops::ControlFlow};
@@ -31,6 +34,12 @@ pub struct Dict<T = PyObjectRef> {
inner: PyRwLock<DictInner<T>>,
}
unsafe impl<T: Traverse> Traverse for Dict<T> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.inner.traverse(tracer_fn);
}
}
impl<T> fmt::Debug for Dict<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Debug").finish()
@@ -69,6 +78,20 @@ struct DictInner<T> {
entries: Vec<Option<DictEntry<T>>>,
}
unsafe impl<T: Traverse> Traverse for DictInner<T> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.entries
.iter()
.map(|v| {
if let Some(v) = v {
v.key.traverse(tracer_fn);
v.value.traverse(tracer_fn);
}
})
.count();
}
}
impl<T: Clone> Clone for Dict<T> {
fn clone(&self) -> Self {
Self {

View File

@@ -1,5 +1,6 @@
use self::types::{PyBaseException, PyBaseExceptionRef};
use crate::common::{lock::PyRwLock, str::ReprOverflowError};
use crate::object::{Traverse, TraverseFn};
use crate::{
builtins::{
traceback::PyTracebackRef, tuple::IntoPyTuple, PyNone, PyStr, PyStrRef, PyTuple,
@@ -21,6 +22,15 @@ use std::{
io::{self, BufRead, BufReader},
};
unsafe impl Traverse for PyBaseException {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.traceback.traverse(tracer_fn);
self.cause.traverse(tracer_fn);
self.context.traverse(tracer_fn);
self.args.traverse(tracer_fn);
}
}
impl std::fmt::Debug for PyBaseException {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
// TODO: implement more detailed, non-recursive Debug formatter
@@ -1143,7 +1153,7 @@ pub(super) mod types {
// Sorted By Hierarchy then alphabetized.
#[pyclass(module = false, name = "BaseException")]
#[pyclass(module = false, name = "BaseException", traverse = "manual")]
pub struct PyBaseException {
pub(super) traceback: PyRwLock<Option<PyTracebackRef>>,
pub(super) cause: PyRwLock<Option<PyRef<Self>>>,

View File

@@ -1,6 +1,7 @@
use crate::{
builtins::{PyBaseExceptionRef, PyTupleRef, PyTypeRef},
convert::ToPyObject,
object::{Traverse, TraverseFn},
AsObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine,
};
use indexmap::IndexMap;
@@ -57,13 +58,19 @@ into_func_args_from_tuple!((v1, T1), (v2, T2), (v3, T3), (v4, T4), (v5, T5));
/// The `FuncArgs` struct is one of the most used structs then creating
/// a rust function that can be called from python. It holds both positional
/// arguments, as well as keyword arguments passed to the function.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Default, Clone, Traverse)]
pub struct FuncArgs {
pub args: Vec<PyObjectRef>,
// sorted map, according to https://www.python.org/dev/peps/pep-0468/
pub kwargs: IndexMap<String, PyObjectRef>,
}
unsafe impl Traverse for IndexMap<String, PyObjectRef> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.values().for_each(|v| v.traverse(tracer_fn));
}
}
/// Conversion from vector of python objects to function arguments.
impl<A> From<A> for FuncArgs
where
@@ -320,6 +327,15 @@ impl<T: TryFromObject> FromArgOptional for T {
#[derive(Clone)]
pub struct KwArgs<T = PyObjectRef>(IndexMap<String, T>);
unsafe impl<T> Traverse for KwArgs<T>
where
T: Traverse,
{
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.iter().map(|(_, v)| v.traverse(tracer_fn)).count();
}
}
impl<T> KwArgs<T> {
pub fn new(map: IndexMap<String, T>) -> Self {
KwArgs(map)
@@ -377,6 +393,15 @@ impl<T> IntoIterator for KwArgs<T> {
#[derive(Clone)]
pub struct PosArgs<T = PyObjectRef>(Vec<T>);
unsafe impl<T> Traverse for PosArgs<T>
where
T: Traverse,
{
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.traverse(tracer_fn)
}
}
impl<T> PosArgs<T> {
pub fn new(args: Vec<T>) -> Self {
Self(args)
@@ -461,6 +486,18 @@ pub enum OptionalArg<T = PyObjectRef> {
Missing,
}
unsafe impl<T> Traverse for OptionalArg<T>
where
T: Traverse,
{
fn traverse(&self, tracer_fn: &mut TraverseFn) {
match self {
OptionalArg::Present(ref o) => o.traverse(tracer_fn),
OptionalArg::Missing => (),
}
}
}
impl OptionalArg<PyObjectRef> {
pub fn unwrap_or_none(self, vm: &VirtualMachine) -> PyObjectRef {
self.unwrap_or_else(|| vm.ctx.none())

View File

@@ -9,7 +9,7 @@ use crate::{
// Python/getargs.c
/// any bytes-like object. Like the `y*` format code for `PyArg_Parse` in CPython.
#[derive(Debug)]
#[derive(Debug, Traverse)]
pub struct ArgBytesLike(PyBuffer);
impl PyObject {
@@ -82,7 +82,7 @@ impl<'a> TryFromBorrowedObject<'a> for ArgBytesLike {
}
/// A memory buffer, read-write access. Like the `w*` format code for `PyArg_Parse` in CPython.
#[derive(Debug)]
#[derive(Debug, Traverse)]
pub struct ArgMemoryBuffer(PyBuffer);
impl ArgMemoryBuffer {

View File

@@ -130,7 +130,7 @@ impl TryFromObject for ArgIntoBool {
}
// Implement ArgIndex to separate between "true" int and int generated by index
#[derive(Debug)]
#[derive(Debug, Traverse)]
#[repr(transparent)]
pub struct ArgIndex {
value: PyIntRef,

View File

@@ -3,15 +3,17 @@ use crate::{
builtins::{iter::PySequenceIterator, PyDict, PyDictRef},
convert::ToPyObject,
identifier,
object::{Traverse, TraverseFn},
protocol::{PyIter, PyIterIter, PyMapping, PyMappingMethods},
types::{AsMapping, GenericMethod},
AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine,
};
use std::{borrow::Borrow, marker::PhantomData, ops::Deref};
#[derive(Clone)]
#[derive(Clone, Traverse)]
pub struct ArgCallable {
obj: PyObjectRef,
#[pytraverse(skip)]
call: GenericMethod,
}
@@ -75,6 +77,12 @@ pub struct ArgIterable<T = PyObjectRef> {
_item: PhantomData<T>,
}
unsafe impl<T: Traverse> Traverse for ArgIterable<T> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.iterable.traverse(tracer_fn)
}
}
impl<T> ArgIterable<T> {
/// Returns an iterator over this sequence of objects.
///
@@ -110,9 +118,10 @@ where
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Traverse)]
pub struct ArgMapping {
obj: PyObjectRef,
#[pytraverse(skip)]
methods: &'static PyMappingMethods,
}
@@ -187,6 +196,12 @@ impl TryFromObject for ArgMapping {
#[derive(Clone)]
pub struct ArgSequence<T = PyObjectRef>(Vec<T>);
unsafe impl<T: Traverse> Traverse for ArgSequence<T> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.traverse(tracer_fn);
}
}
impl<T> ArgSequence<T> {
#[inline(always)]
pub fn into_vec(self) -> Vec<T> {

View File

@@ -10,12 +10,13 @@
//!
//! PyRef<PyWeak> may looking like to be called as PyObjectWeak by the rule,
//! but not to do to remember it is a PyRef object.
use super::{
ext::{AsObject, PyRefExact, PyResult},
payload::PyObjectPayload,
PyAtomicRef,
};
use crate::object::traverse::{Traverse, TraverseFn};
use crate::object::traverse_object::PyObjVTable;
use crate::{
builtins::{PyDictRef, PyType, PyTypeRef},
common::{
@@ -73,52 +74,42 @@ use std::{
/// A type to just represent "we've erased the type of this object, cast it before you use it"
#[derive(Debug)]
struct Erased;
pub(super) struct Erased;
struct PyObjVTable {
drop_dealloc: unsafe fn(*mut PyObject),
debug: unsafe fn(&PyObject, &mut fmt::Formatter) -> fmt::Result,
}
unsafe fn drop_dealloc_obj<T: PyObjectPayload>(x: *mut PyObject) {
pub(super) unsafe fn drop_dealloc_obj<T: PyObjectPayload>(x: *mut PyObject) {
drop(Box::from_raw(x as *mut PyInner<T>));
}
unsafe fn debug_obj<T: PyObjectPayload>(x: &PyObject, f: &mut fmt::Formatter) -> fmt::Result {
pub(super) unsafe fn debug_obj<T: PyObjectPayload>(
x: &PyObject,
f: &mut fmt::Formatter,
) -> fmt::Result {
let x = &*(x as *const PyObject as *const PyInner<T>);
fmt::Debug::fmt(x, f)
}
impl PyObjVTable {
pub fn of<T: PyObjectPayload>() -> &'static Self {
struct Helper<T: PyObjectPayload>(PhantomData<T>);
trait VtableHelper {
const VTABLE: PyObjVTable;
}
impl<T: PyObjectPayload> VtableHelper for Helper<T> {
const VTABLE: PyObjVTable = PyObjVTable {
drop_dealloc: drop_dealloc_obj::<T>,
debug: debug_obj::<T>,
};
}
&Helper::<T>::VTABLE
}
/// Call `try_trace` on payload
pub(super) unsafe fn try_trace_obj<T: PyObjectPayload>(x: &PyObject, tracer_fn: &mut TraverseFn) {
let x = &*(x as *const PyObject as *const PyInner<T>);
let payload = &x.payload;
payload.try_traverse(tracer_fn)
}
/// This is an actual python object. It consists of a `typ` which is the
/// python class, and carries some rust payload optionally. This rust
/// payload can be a rust float or rust int in case of float and int objects.
#[repr(C)]
struct PyInner<T> {
ref_count: RefCount,
pub(super) struct PyInner<T> {
pub(super) ref_count: RefCount,
// TODO: move typeid into vtable once TypeId::of is const
typeid: TypeId,
vtable: &'static PyObjVTable,
pub(super) typeid: TypeId,
pub(super) vtable: &'static PyObjVTable,
typ: PyAtomicRef<PyType>, // __class__ member
dict: Option<InstanceDict>,
weak_list: WeakRefList,
slots: Box<[PyRwLock<Option<PyObjectRef>>]>,
pub(super) typ: PyAtomicRef<PyType>, // __class__ member
pub(super) dict: Option<InstanceDict>,
pub(super) weak_list: WeakRefList,
pub(super) slots: Box<[PyRwLock<Option<PyObjectRef>>]>,
payload: T,
pub(super) payload: T,
}
impl<T: fmt::Debug> fmt::Debug for PyInner<T> {
@@ -127,7 +118,23 @@ impl<T: fmt::Debug> fmt::Debug for PyInner<T> {
}
}
struct WeakRefList {
unsafe impl<T: PyObjectPayload> Traverse for Py<T> {
/// DO notice that call `trace` on `Py<T>` means apply `tracer_fn` on `Py<T>`'s children,
/// not like call `trace` on `PyRef<T>` which apply `tracer_fn` on `PyRef<T>` itself
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.traverse(tracer_fn)
}
}
unsafe impl Traverse for PyObject {
/// DO notice that call `trace` on `PyObject` means apply `tracer_fn` on `PyObject`'s children,
/// not like call `trace` on `PyObjectRef` which apply `tracer_fn` on `PyObjectRef` itself
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.traverse(tracer_fn)
}
}
pub(super) struct WeakRefList {
inner: OncePtr<PyMutex<WeakListInner>>,
}
@@ -393,8 +400,8 @@ impl Py<PyWeak> {
}
#[derive(Debug)]
struct InstanceDict {
d: PyRwLock<PyDictRef>,
pub(super) struct InstanceDict {
pub(super) d: PyRwLock<PyDictRef>,
}
impl From<PyDictRef> for InstanceDict {

View File

@@ -1,7 +1,10 @@
mod core;
mod ext;
mod payload;
mod traverse;
mod traverse_object;
pub use self::core::*;
pub use self::ext::*;
pub use self::payload::*;
pub use traverse::{MaybeTraverse, Traverse, TraverseFn};

View File

@@ -1,4 +1,4 @@
use super::{Py, PyObjectRef, PyRef, PyResult};
use crate::object::{MaybeTraverse, Py, PyObjectRef, PyRef, PyResult};
use crate::{
builtins::{PyBaseExceptionRef, PyType, PyTypeRef},
types::PyTypeFlags,
@@ -16,7 +16,9 @@ cfg_if::cfg_if! {
}
}
pub trait PyPayload: std::fmt::Debug + PyThreadingConstraint + Sized + 'static {
pub trait PyPayload:
std::fmt::Debug + MaybeTraverse + PyThreadingConstraint + Sized + 'static
{
fn class(ctx: &Context) -> &'static Py<PyType>;
#[inline]
@@ -73,7 +75,7 @@ pub trait PyPayload: std::fmt::Debug + PyThreadingConstraint + Sized + 'static {
}
pub trait PyObjectPayload:
std::any::Any + std::fmt::Debug + PyThreadingConstraint + 'static
std::any::Any + std::fmt::Debug + MaybeTraverse + PyThreadingConstraint + 'static
{
}

231
vm/src/object/traverse.rs Normal file
View File

@@ -0,0 +1,231 @@
use std::ptr::NonNull;
use rustpython_common::lock::{PyMutex, PyRwLock};
use crate::{function::Either, object::PyObjectPayload, AsObject, PyObject, PyObjectRef, PyRef};
pub type TraverseFn<'a> = dyn FnMut(&PyObject) + 'a;
/// This trait is used as a "Optional Trait"(I 'd like to use `Trace?` but it's not allowed yet) for PyObjectPayload type
///
/// impl for PyObjectPayload, `pyclass` proc macro will handle the actual dispatch if type impl `Trace`
/// Every PyObjectPayload impl `MaybeTrace`, which may or may not be traceable
pub trait MaybeTraverse {
/// if is traceable, will be used by vtable to determine
const IS_TRACE: bool = false;
// if this type is traceable, then call with tracer_fn, default to do nothing
fn try_traverse(&self, traverse_fn: &mut TraverseFn);
}
/// Type that need traverse it's children should impl `Traverse`(Not `MaybeTraverse`)
/// # Safety
/// impl `traverse()` with caution! Following those guideline so traverse doesn't cause memory error!:
/// - Make sure that every owned object(Every PyObjectRef/PyRef) is called with traverse_fn **at most once**.
/// If some field is not called, the worst results is just memory leak,
/// but if some field is called repeatly, panic and deadlock can happen.
///
/// - _**DO NOT**_ clone a `PyObjectRef` or `Pyef<T>` in `traverse()`
pub unsafe trait Traverse {
/// impl `traverse()` with caution! Following those guideline so traverse doesn't cause memory error!:
/// - Make sure that every owned object(Every PyObjectRef/PyRef) is called with traverse_fn **at most once**.
/// If some field is not called, the worst results is just memory leak,
/// but if some field is called repeatly, panic and deadlock can happen.
///
/// - _**DO NOT**_ clone a `PyObjectRef` or `Pyef<T>` in `traverse()`
fn traverse(&self, traverse_fn: &mut TraverseFn);
}
unsafe impl Traverse for PyObjectRef {
fn traverse(&self, traverse_fn: &mut TraverseFn) {
traverse_fn(self)
}
}
unsafe impl<T: PyObjectPayload> Traverse for PyRef<T> {
fn traverse(&self, traverse_fn: &mut TraverseFn) {
traverse_fn(self.as_object())
}
}
unsafe impl Traverse for () {
fn traverse(&self, _traverse_fn: &mut TraverseFn) {}
}
unsafe impl<T: Traverse> Traverse for Option<T> {
#[inline]
fn traverse(&self, traverse_fn: &mut TraverseFn) {
if let Some(v) = self {
v.traverse(traverse_fn);
}
}
}
unsafe impl<T> Traverse for [T]
where
T: Traverse,
{
#[inline]
fn traverse(&self, traverse_fn: &mut TraverseFn) {
for elem in self {
elem.traverse(traverse_fn);
}
}
}
unsafe impl<T> Traverse for Box<[T]>
where
T: Traverse,
{
#[inline]
fn traverse(&self, traverse_fn: &mut TraverseFn) {
for elem in &**self {
elem.traverse(traverse_fn);
}
}
}
unsafe impl<T> Traverse for Vec<T>
where
T: Traverse,
{
#[inline]
fn traverse(&self, traverse_fn: &mut TraverseFn) {
for elem in self {
elem.traverse(traverse_fn);
}
}
}
unsafe impl<T: Traverse> Traverse for PyRwLock<T> {
#[inline]
fn traverse(&self, traverse_fn: &mut TraverseFn) {
// if can't get a lock, this means something else is holding the lock,
// but since gc stopped the world, during gc the lock is always held
// so it is safe to ignore those in gc
if let Some(inner) = self.try_read_recursive() {
inner.traverse(traverse_fn)
}
}
}
/// Safety: We can't hold lock during traverse it's child because it may cause deadlock.
/// TODO(discord9): check if this is thread-safe to do
/// (Outside of gc phase, only incref/decref will call trace,
/// and refcnt is atomic, so it should be fine?)
unsafe impl<T: Traverse> Traverse for PyMutex<T> {
#[inline]
fn traverse(&self, traverse_fn: &mut TraverseFn) {
let mut chs: Vec<NonNull<PyObject>> = Vec::new();
if let Some(obj) = self.try_lock() {
obj.traverse(&mut |ch| {
chs.push(NonNull::from(ch));
})
}
chs.iter()
.map(|ch| {
// Safety: during gc, this should be fine, because nothing should write during gc's tracing?
let ch = unsafe { ch.as_ref() };
traverse_fn(ch);
})
.count();
}
}
macro_rules! trace_tuple {
($(($NAME: ident, $NUM: tt)),*) => {
unsafe impl<$($NAME: Traverse),*> Traverse for ($($NAME),*) {
#[inline]
fn traverse(&self, traverse_fn: &mut TraverseFn) {
$(
self.$NUM.traverse(traverse_fn);
)*
}
}
};
}
unsafe impl<A: Traverse, B: Traverse> Traverse for Either<A, B> {
#[inline]
fn traverse(&self, tracer_fn: &mut TraverseFn) {
match self {
Either::A(a) => a.traverse(tracer_fn),
Either::B(b) => b.traverse(tracer_fn),
}
}
}
// only tuple with 12 elements or less is supported,
// because long tuple is extremly rare in almost every case
unsafe impl<A: Traverse> Traverse for (A,) {
#[inline]
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.traverse(tracer_fn);
}
}
trace_tuple!((A, 0), (B, 1));
trace_tuple!((A, 0), (B, 1), (C, 2));
trace_tuple!((A, 0), (B, 1), (C, 2), (D, 3));
trace_tuple!((A, 0), (B, 1), (C, 2), (D, 3), (E, 4));
trace_tuple!((A, 0), (B, 1), (C, 2), (D, 3), (E, 4), (F, 5));
trace_tuple!((A, 0), (B, 1), (C, 2), (D, 3), (E, 4), (F, 5), (G, 6));
trace_tuple!(
(A, 0),
(B, 1),
(C, 2),
(D, 3),
(E, 4),
(F, 5),
(G, 6),
(H, 7)
);
trace_tuple!(
(A, 0),
(B, 1),
(C, 2),
(D, 3),
(E, 4),
(F, 5),
(G, 6),
(H, 7),
(I, 8)
);
trace_tuple!(
(A, 0),
(B, 1),
(C, 2),
(D, 3),
(E, 4),
(F, 5),
(G, 6),
(H, 7),
(I, 8),
(J, 9)
);
trace_tuple!(
(A, 0),
(B, 1),
(C, 2),
(D, 3),
(E, 4),
(F, 5),
(G, 6),
(H, 7),
(I, 8),
(J, 9),
(K, 10)
);
trace_tuple!(
(A, 0),
(B, 1),
(C, 2),
(D, 3),
(E, 4),
(F, 5),
(G, 6),
(H, 7),
(I, 8),
(J, 9),
(K, 10),
(L, 11)
);

View File

@@ -0,0 +1,78 @@
use std::{fmt, marker::PhantomData};
use crate::{
object::{
debug_obj, drop_dealloc_obj, try_trace_obj, Erased, InstanceDict, PyInner, PyObjectPayload,
},
PyObject,
};
use super::{Traverse, TraverseFn};
pub(in crate::object) struct PyObjVTable {
pub(in crate::object) drop_dealloc: unsafe fn(*mut PyObject),
pub(in crate::object) debug: unsafe fn(&PyObject, &mut fmt::Formatter) -> fmt::Result,
pub(in crate::object) trace: Option<unsafe fn(&PyObject, &mut TraverseFn)>,
}
impl PyObjVTable {
pub fn of<T: PyObjectPayload>() -> &'static Self {
struct Helper<T: PyObjectPayload>(PhantomData<T>);
trait VtableHelper {
const VTABLE: PyObjVTable;
}
impl<T: PyObjectPayload> VtableHelper for Helper<T> {
const VTABLE: PyObjVTable = PyObjVTable {
drop_dealloc: drop_dealloc_obj::<T>,
debug: debug_obj::<T>,
trace: {
if T::IS_TRACE {
Some(try_trace_obj::<T>)
} else {
None
}
},
};
}
&Helper::<T>::VTABLE
}
}
unsafe impl Traverse for InstanceDict {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.d.traverse(tracer_fn)
}
}
unsafe impl Traverse for PyInner<Erased> {
/// Because PyObject hold a `PyInner<Erased>`, so we need to trace it
fn traverse(&self, tracer_fn: &mut TraverseFn) {
// 1. trace `dict` and `slots` field(`typ` can't trace for it's a AtomicRef while is leaked by design)
// 2. call vtable's trace function to trace payload
// self.typ.trace(tracer_fn);
self.dict.traverse(tracer_fn);
// weak_list keeps a *pointer* to a struct for maintaince weak ref, so no ownership, no trace
self.slots.traverse(tracer_fn);
if let Some(f) = self.vtable.trace {
unsafe {
let zelf = &*(self as *const PyInner<Erased> as *const PyObject);
f(zelf, tracer_fn)
}
};
}
}
unsafe impl<T: PyObjectPayload> Traverse for PyInner<T> {
/// Type is known, so we can call `try_trace` directly instead of using erased type vtable
fn traverse(&self, tracer_fn: &mut TraverseFn) {
// 1. trace `dict` and `slots` field(`typ` can't trace for it's a AtomicRef while is leaked by design)
// 2. call corrsponding `try_trace` function to trace payload
// (No need to call vtable's trace function because we already know the type)
// self.typ.trace(tracer_fn);
self.dict.traverse(tracer_fn);
// weak_list keeps a *pointer* to a struct for maintaince weak ref, so no ownership, no trace
self.slots.traverse(tracer_fn);
T::try_traverse(&self.payload, tracer_fn);
}
}

View File

@@ -32,10 +32,12 @@ impl Debug for BufferMethods {
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Traverse)]
pub struct PyBuffer {
pub obj: PyObjectRef,
#[pytraverse(skip)]
pub desc: BufferDescriptor,
#[pytraverse(skip)]
methods: &'static BufferMethods,
}

View File

@@ -1,6 +1,7 @@
use crate::{
builtins::iter::PySequenceIterator,
convert::{ToPyObject, ToPyResult},
object::{Traverse, TraverseFn},
AsObject, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine,
};
use std::borrow::Borrow;
@@ -14,6 +15,12 @@ pub struct PyIter<O = PyObjectRef>(O)
where
O: Borrow<PyObject>;
unsafe impl<O: Borrow<PyObject>> Traverse for PyIter<O> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.borrow().traverse(tracer_fn);
}
}
impl PyIter<PyObjectRef> {
pub fn check(obj: &PyObject) -> bool {
obj.class()
@@ -149,6 +156,16 @@ pub enum PyIterReturn<T = PyObjectRef> {
StopIteration(Option<PyObjectRef>),
}
unsafe impl<T: Traverse> Traverse for PyIterReturn<T> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
match self {
PyIterReturn::Return(r) => r.traverse(tracer_fn),
PyIterReturn::StopIteration(Some(obj)) => obj.traverse(tracer_fn),
_ => (),
}
}
}
impl PyIterReturn {
pub fn from_pyresult(result: PyResult, vm: &VirtualMachine) -> PyResult<Self> {
match result {
@@ -212,6 +229,15 @@ where
_phantom: std::marker::PhantomData<T>,
}
unsafe impl<'a, T, O> Traverse for PyIterIter<'a, T, O>
where
O: Traverse + Borrow<PyObject>,
{
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.obj.traverse(tracer_fn)
}
}
impl<'a, T, O> PyIterIter<'a, T, O>
where
O: Borrow<PyObject>,

View File

@@ -5,6 +5,7 @@ use crate::{
PyDict, PyStrInterned,
},
convert::ToPyResult,
object::{Traverse, TraverseFn},
AsObject, PyObject, PyObjectRef, PyResult, VirtualMachine,
};
use crossbeam_utils::atomic::AtomicCell;
@@ -62,6 +63,12 @@ pub struct PyMapping<'a> {
pub methods: &'static PyMappingMethods,
}
unsafe impl Traverse for PyMapping<'_> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.obj.traverse(tracer_fn)
}
}
impl AsRef<PyObject> for PyMapping<'_> {
#[inline(always)]
fn as_ref(&self) -> &PyObject {

View File

@@ -6,6 +6,7 @@ use crate::{
builtins::{int, PyByteArray, PyBytes, PyComplex, PyFloat, PyInt, PyIntRef, PyStr},
common::int::bytes_to_int,
function::ArgBytesLike,
object::{Traverse, TraverseFn},
stdlib::warnings,
AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject,
VirtualMachine,
@@ -426,6 +427,12 @@ impl PyNumberSlots {
#[derive(Copy, Clone)]
pub struct PyNumber<'a>(&'a PyObject);
unsafe impl Traverse for PyNumber<'_> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.0.traverse(tracer_fn)
}
}
impl<'a> Deref for PyNumber<'a> {
type Target = PyObject;

View File

@@ -2,6 +2,7 @@ use crate::{
builtins::{type_::PointerSlot, PyList, PyListRef, PySlice, PyTuple, PyTupleRef},
convert::ToPyObject,
function::PyArithmeticValue,
object::{Traverse, TraverseFn},
protocol::{PyMapping, PyNumberBinaryOp},
AsObject, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
};
@@ -65,6 +66,12 @@ pub struct PySequence<'a> {
pub methods: &'static PySequenceMethods,
}
unsafe impl Traverse for PySequence<'_> {
fn traverse(&self, tracer_fn: &mut TraverseFn) {
self.obj.traverse(tracer_fn)
}
}
impl<'a> PySequence<'a> {
#[inline]
pub fn with_methods(obj: &'a PyObject, methods: &'static PySequenceMethods) -> Self {