pyclass macro to recognize Py/PyRef pattern

This commit is contained in:
Jeong YunWon
2023-03-16 18:55:46 +09:00
parent fd028253fe
commit 0f24d66234
5 changed files with 120 additions and 46 deletions

View File

@@ -105,39 +105,114 @@ pub(crate) fn impl_pyimpl(attr: AttributeArgs, item: Item) -> Result<TokenStream
Item::Impl(mut imp) => {
extract_items_into_context(&mut context, imp.items.iter_mut());
let ty = &imp.self_ty;
let (impl_ty, payload_guess) = match imp.self_ty.as_ref() {
syn::Type::Path(syn::TypePath {
path: syn::Path { segments, .. },
..
}) if segments.len() == 1 => {
let segment = &segments[0];
let payload_ty = if segment.ident == "Py" || segment.ident == "PyRef" {
match &segment.arguments {
syn::PathArguments::AngleBracketed(
syn::AngleBracketedGenericArguments { args, .. },
) if args.len() == 1 => {
let arg = &args[0];
match arg {
syn::GenericArgument::Type(syn::Type::Path(
syn::TypePath {
path: syn::Path { segments, .. },
..
},
)) if segments.len() == 1 => segments[0].ident.clone(),
_ => {
return Err(syn::Error::new_spanned(
segment,
"Py{Ref}<T> is expected but Py{Ref}<?> is found",
))
}
}
}
_ => {
return Err(syn::Error::new_spanned(
segment,
"Py{Ref}<T> is expected but Py{Ref}? is found",
))
}
}
} else {
if !matches!(segment.arguments, syn::PathArguments::None) {
return Err(syn::Error::new_spanned(
segment,
"PyImpl can only be implemented for Py{Ref}<T> or T",
));
}
segment.ident.clone()
};
(segment.ident.clone(), payload_ty)
}
_ => {
return Err(syn::Error::new_spanned(
imp.self_ty,
"PyImpl can only be implemented for Py{Ref}<T> or T",
))
}
};
let ExtractedImplAttrs {
payload: attr_payload,
with_impl,
flags,
with_slots,
} = extract_impl_attrs(attr, &Ident::new(&quote!(ty).to_string(), ty.span()))?;
} = extract_impl_attrs(attr, &impl_ty)?;
let payload_ty = attr_payload.unwrap_or(payload_guess);
let getset_impl = &context.getset_items;
let member_impl = &context.member_items;
let extend_impl = context.impl_extend_items.validate()?;
let slots_impl = context.extend_slots_items.validate()?;
let class_extensions = &context.class_extensions;
quote! {
#imp
impl ::rustpython_vm::class::PyClassImpl for #ty {
const TP_FLAGS: ::rustpython_vm::types::PyTypeFlags = #flags;
fn impl_extend_class(
let extra_methods = iter_chain![
parse_quote! {
fn __extend_py_class(
ctx: &::rustpython_vm::Context,
class: &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType>,
) {
#getset_impl
#member_impl
#extend_impl
#with_impl
#(#class_extensions)*
}
fn extend_slots(slots: &mut ::rustpython_vm::types::PyTypeSlots) {
#with_slots
},
parse_quote! {
fn __extend_slots(slots: &mut ::rustpython_vm::types::PyTypeSlots) {
#slots_impl
}
},
];
imp.items.extend(extra_methods);
let is_main_impl = impl_ty == payload_ty;
if is_main_impl {
quote! {
#imp
impl ::rustpython_vm::class::PyClassImpl for #payload_ty {
const TP_FLAGS: ::rustpython_vm::types::PyTypeFlags = #flags;
fn impl_extend_class(
ctx: &::rustpython_vm::Context,
class: &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType>,
) {
#impl_ty::__extend_py_class(ctx, class);
#with_impl
}
fn extend_slots(slots: &mut ::rustpython_vm::types::PyTypeSlots) {
#impl_ty::__extend_slots(slots);
#with_slots
}
}
}
} else {
imp.into_token_stream()
}
}
Item::Trait(mut trai) => {
@@ -1163,6 +1238,7 @@ impl MemberItemMeta {
}
struct ExtractedImplAttrs {
payload: Option<Ident>,
with_impl: TokenStream,
with_slots: TokenStream,
flags: TokenStream,
@@ -1182,6 +1258,7 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
}
}
}];
let mut payload = None;
for attr in attr {
match attr {
@@ -1191,18 +1268,19 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
let NestedMeta::Meta(Meta::Path(path)) = meta else {
bail_span!(meta, "#[pyclass(with(...))] arguments should be paths")
};
let (extend_class, extend_slots) = if path.is_ident("PyRef") {
// special handling for PyRef
(
quote!(PyRef::<Self>::impl_extend_class),
quote!(PyRef::<Self>::extend_slots),
)
} else {
(
quote!(<Self as #path>::__extend_py_class),
quote!(<Self as #path>::__extend_slots),
)
};
let (extend_class, extend_slots) =
if path.is_ident("PyRef") || path.is_ident("Py") {
// special handling for PyRef
(
quote!(#path::<Self>::__extend_py_class),
quote!(#path::<Self>::__extend_slots),
)
} else {
(
quote!(<Self as #path>::__extend_py_class),
quote!(<Self as #path>::__extend_slots),
)
};
let item_span = item.span().resolved_at(Span::call_site());
withs.push(quote_spanned! { path.span() =>
#extend_class(ctx, class);
@@ -1227,11 +1305,23 @@ fn extract_impl_attrs(attr: AttributeArgs, item: &Ident) -> Result<ExtractedImpl
bail_span!(path, "Unknown pyimpl attribute")
}
}
NestedMeta::Meta(Meta::NameValue(syn::MetaNameValue { path, lit, .. })) => {
if path.is_ident("payload") {
if let syn::Lit::Str(lit) = lit {
payload = Some(Ident::new(&lit.value(), lit.span()));
} else {
bail_span!(lit, "payload must be a string literal")
}
} else {
bail_span!(path, "Unknown pyimpl attribute")
}
}
attr => bail_span!(attr, "Unknown pyimpl attribute"),
}
}
Ok(ExtractedImplAttrs {
payload,
with_impl: quote! {
#(#withs)*
},