Allow tuple structs for derive(FromArgs)

This commit is contained in:
Noah
2021-02-01 12:06:18 -06:00
parent ce8b5eed90
commit 561f4ee779
2 changed files with 32 additions and 47 deletions

View File

@@ -1,10 +1,8 @@
use crate::util::path_eq;
use crate::Diagnostic;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
parse_quote, Attribute, Data, DeriveInput, Expr, Field, Fields, Ident, Lit, Meta, NestedMeta,
};
use quote::{quote, ToTokens};
use syn::{parse_quote, Attribute, Data, DeriveInput, Expr, Field, Ident, Lit, Meta, NestedMeta};
/// The kind of the python parameter, this corresponds to the value of Parameter.kind
/// (https://docs.python.org/3/library/inspect.html#inspect.Parameter.kind)
@@ -121,7 +119,7 @@ impl ArgAttribute {
}
}
fn generate_field(field: &Field) -> Result<TokenStream2, Diagnostic> {
fn generate_field((i, field): (usize, &Field)) -> Result<TokenStream2, Diagnostic> {
let mut pyarg_attrs = field
.attrs
.iter()
@@ -139,19 +137,26 @@ fn generate_field(field: &Field) -> Result<TokenStream2, Diagnostic> {
bail_span!(field, "Multiple pyarg attributes on field");
};
let name = field.ident.as_ref().unwrap();
let namestring = name.to_string();
if namestring.starts_with("_phantom") {
let name = field.ident.as_ref();
let namestring = name.map(Ident::to_string);
if matches!(&namestring, Some(s) if s.starts_with("_phantom")) {
return Ok(quote! {
#name: ::std::marker::PhantomData,
});
}
let fieldname = match name {
Some(id) => id.to_token_stream(),
None => syn::Index::from(i).into_token_stream(),
};
if let ParameterKind::Flatten = attr.kind {
return Ok(quote! {
#name: ::rustpython_vm::function::FromArgs::from_args(vm, args)?,
#fieldname: ::rustpython_vm::function::FromArgs::from_args(vm, args)?,
});
}
let pyname = attr.name.unwrap_or(namestring);
let pyname = attr
.name
.or(namestring)
.ok_or_else(|| err_span!(field, "field in tuple struct must have name attribute"))?;
let middle = quote! {
.map(|x| ::rustpython_vm::pyobject::TryFromObject::try_from_object(vm, x)).transpose()?
};
@@ -179,17 +184,17 @@ fn generate_field(field: &Field) -> Result<TokenStream2, Diagnostic> {
let file_output = match attr.kind {
ParameterKind::PositionalOnly => {
quote! {
#name: args.take_positional()#middle#ending,
#fieldname: args.take_positional()#middle#ending,
}
}
ParameterKind::PositionalOrKeyword => {
quote! {
#name: args.take_positional_keyword(#pyname)#middle#ending,
#fieldname: args.take_positional_keyword(#pyname)#middle#ending,
}
}
ParameterKind::KeywordOnly => {
quote! {
#name: args.take_keyword(#pyname)#middle#ending,
#fieldname: args.take_keyword(#pyname)#middle#ending,
}
}
ParameterKind::Flatten => unreachable!(),
@@ -199,15 +204,12 @@ fn generate_field(field: &Field) -> Result<TokenStream2, Diagnostic> {
pub fn impl_from_args(input: DeriveInput) -> Result<TokenStream2, Diagnostic> {
let fields = match input.data {
Data::Struct(syn::DataStruct {
fields: Fields::Named(fields),
..
}) => fields
.named
Data::Struct(syn::DataStruct { fields, .. }) => fields
.iter()
.enumerate()
.map(generate_field)
.collect::<Result<TokenStream2, Diagnostic>>()?,
_ => bail_span!(input, "FromArgs input must be a struct with named fields"),
_ => bail_span!(input, "FromArgs input must be a struct"),
};
let name = input.ident;

View File

@@ -262,10 +262,7 @@ impl FromArgs for DirFd {
}
#[derive(FromArgs)]
struct FollowSymlinks {
#[pyarg(named, default = "true")]
follow_symlinks: bool,
}
struct FollowSymlinks(#[pyarg(named, name = "follow_symlinks", default = "true")] bool);
#[cfg(unix)]
use posix::bytes_as_osstr;
@@ -610,7 +607,7 @@ mod _os {
action: fn(fs::Metadata) -> bool,
vm: &VirtualMachine,
) -> PyResult<bool> {
let meta = fs_metadata(self.entry.path(), follow_symlinks.follow_symlinks)
let meta = fs_metadata(self.entry.path(), follow_symlinks.0)
.map_err(|err| err.into_pyexception(vm))?;
Ok(action(meta))
}
@@ -771,14 +768,7 @@ mod _os {
#[pyfunction]
fn lstat(file: Either<PyPathLike, i64>, dir_fd: DirFd, vm: &VirtualMachine) -> PyResult {
super::platform::stat(
file,
dir_fd,
FollowSymlinks {
follow_symlinks: false,
},
vm,
)
super::platform::stat(file, dir_fd, FollowSymlinks(false), vm)
}
#[pyfunction]
@@ -987,7 +977,7 @@ mod _os {
path,
&acc.into(),
&modif.into(),
if _follow_symlinks.follow_symlinks {
if _follow_symlinks.0 {
nix::sys::stat::UtimensatFlags::FollowSymlink
} else {
nix::sys::stat::UtimensatFlags::NoFollowSymlink
@@ -1494,10 +1484,7 @@ mod posix {
use std::os::redox::fs::MetadataExt;
let meta = match file {
Either::A(path) => fs_metadata(
make_path(vm, &path, &dir_fd)?,
follow_symlinks.follow_symlinks,
),
Either::A(path) => fs_metadata(make_path(vm, &path, &dir_fd)?, follow_symlinks.0),
Either::B(fno) => {
let file = rust_file(fno);
let res = file.metadata();
@@ -1593,7 +1580,7 @@ mod posix {
return Err(vm.new_os_error(String::from("Specified gid is not valid.")));
};
let flag = if follow_symlinks.follow_symlinks {
let flag = if follow_symlinks.0 {
nix::unistd::FchownatFlags::FollowSymlink
} else {
nix::unistd::FchownatFlags::NoFollowSymlink
@@ -1619,9 +1606,7 @@ mod posix {
uid,
gid,
DirFd(None),
FollowSymlinks {
follow_symlinks: false,
},
FollowSymlinks(false),
vm,
)
}
@@ -1634,9 +1619,7 @@ mod posix {
uid,
gid,
DirFd(None),
FollowSymlinks {
follow_symlinks: true,
},
FollowSymlinks(true),
vm,
)
}
@@ -1751,7 +1734,7 @@ mod posix {
let path = make_path(vm, &path, &dir_fd)?;
let body = move || {
use std::os::unix::fs::PermissionsExt;
let meta = fs_metadata(path, follow_symlinks.follow_symlinks)?;
let meta = fs_metadata(path, follow_symlinks.0)?;
let mut permissions = meta.permissions();
permissions.set_mode(mode);
fs::set_permissions(path, permissions)
@@ -2616,7 +2599,7 @@ mod nt {
let get_stats = move || -> io::Result<PyObjectRef> {
let meta = match file {
Either::A(path) => fs_metadata(path.path, follow_symlinks.follow_symlinks)?,
Either::A(path) => fs_metadata(path.path, follow_symlinks.0)?,
Either::B(fno) => {
let f = rust_file(fno);
let meta = f.metadata()?;
@@ -2662,7 +2645,7 @@ mod nt {
) -> PyResult<()> {
const S_IWRITE: u32 = 128;
let path = make_path(vm, &path, &dir_fd)?;
let metadata = if follow_symlinks.follow_symlinks {
let metadata = if follow_symlinks.0 {
fs::metadata(path)
} else {
fs::symlink_metadata(path)