Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic support for generic structs and enums #483

Merged
merged 9 commits into from Jul 1, 2021
96 changes: 82 additions & 14 deletions structopt-derive/src/lib.rs
Expand Up @@ -559,10 +559,10 @@ fn gen_augment_clap_enum(
}
}

fn gen_from_clap_enum(name: &Ident) -> TokenStream {
fn gen_from_clap_enum() -> TokenStream {
quote! {
fn from_clap(matches: &::structopt::clap::ArgMatches) -> Self {
<#name as ::structopt::StructOptInternal>::from_subcommand(matches.subcommand())
<Self as ::structopt::StructOptInternal>::from_subcommand(matches.subcommand())
.expect("structopt misuse: You likely tried to #[flatten] a struct \
that contains #[subcommand]. This is forbidden.")
}
Expand Down Expand Up @@ -736,9 +736,9 @@ fn gen_from_subcommand(
}

#[cfg(feature = "paw")]
fn gen_paw_impl(name: &Ident) -> TokenStream {
fn gen_paw_impl(impl_generics: &ImplGenerics, name: &Ident, ty_generics: &TypeGenerics, where_clause: &TokenStream) -> TokenStream {
quote! {
impl ::structopt::paw::ParseArgs for #name {
impl #impl_generics ::structopt::paw::ParseArgs for #name #ty_generics #where_clause {
type Error = std::io::Error;

fn parse_args() -> std::result::Result<Self, Self::Error> {
Expand All @@ -748,19 +748,83 @@ fn gen_paw_impl(name: &Ident) -> TokenStream {
}
}
#[cfg(not(feature = "paw"))]
fn gen_paw_impl(_: &Ident) -> TokenStream {
fn gen_paw_impl(_: &ImplGenerics, _: &Ident, _: &TypeGenerics, _: &TokenStream) -> TokenStream {
TokenStream::new()
}

fn split_structopt_generics_for_impl(generics: &Generics) -> (ImplGenerics, TypeGenerics, TokenStream) {
use syn::{ token::Add, TypeParamBound::Trait };

fn path_is_structop(path: &Path) -> bool {
njeffords marked this conversation as resolved.
Show resolved Hide resolved
path.segments.last().unwrap().ident == "StructOpt"
}

fn type_param_bounds_contains_structop(bounds: &Punctuated<TypeParamBound, Add>) -> bool {
njeffords marked this conversation as resolved.
Show resolved Hide resolved
for bound in bounds {
if let Trait(bound) = bound {
if path_is_structop(&bound.path) {
return true;
}
}
}
return false;
}

let mut trait_bound_amendments = TokenStream::new();

for param in &generics.params {
if let GenericParam::Type(param) = param {
let param_ident = &param.ident;
if type_param_bounds_contains_structop(&param.bounds) {
if !trait_bound_amendments.is_empty() {
trait_bound_amendments.extend(quote!{ , });
}
trait_bound_amendments.extend(quote!{ #param_ident : ::structopt::StructOptInternal });
}
}
}

if let Some(where_clause) = &generics.where_clause {
for predicate in &where_clause.predicates {
if let WherePredicate::Type(predicate) = predicate {
let predicate_bounded_ty = &predicate.bounded_ty;
if type_param_bounds_contains_structop(&predicate.bounds) {
if !trait_bound_amendments.is_empty() {
trait_bound_amendments.extend(quote!{ , });
}
trait_bound_amendments.extend(quote!{ #predicate_bounded_ty : ::structopt::StructOptInternal });
}
}
}
}

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let where_clause = if !trait_bound_amendments.is_empty() {
if let Some(where_clause) = where_clause {
quote!{ #where_clause, #trait_bound_amendments }
} else {
quote!{ where #trait_bound_amendments }
}
} else {
quote!{ #where_clause }
};

(impl_generics, ty_generics, where_clause)
}

fn impl_structopt_for_struct(
name: &Ident,
fields: &Punctuated<Field, Comma>,
attrs: &[Attribute],
generics: &Generics,
) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics);

let basic_clap_app_gen = gen_clap_struct(attrs);
let augment_clap = gen_augment_clap(fields, &basic_clap_app_gen.attrs);
let from_clap = gen_from_clap(name, fields, &basic_clap_app_gen.attrs);
let paw_impl = gen_paw_impl(name);
let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause);

let clap_tokens = basic_clap_app_gen.tokens;
quote! {
Expand All @@ -778,7 +842,7 @@ fn impl_structopt_for_struct(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOpt for #name {
impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause {
#clap_tokens
#from_clap
}
Expand All @@ -797,7 +861,7 @@ fn impl_structopt_for_struct(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOptInternal for #name {
impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause {
#augment_clap
fn is_subcommand() -> bool { false }
}
Expand All @@ -810,15 +874,19 @@ fn impl_structopt_for_enum(
name: &Ident,
variants: &Punctuated<Variant, Comma>,
attrs: &[Attribute],
generics: &Generics,
) -> TokenStream {

let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics);

let basic_clap_app_gen = gen_clap_enum(attrs);
let clap_tokens = basic_clap_app_gen.tokens;
let attrs = basic_clap_app_gen.attrs;

let augment_clap = gen_augment_clap_enum(variants, &attrs);
let from_clap = gen_from_clap_enum(name);
let from_clap = gen_from_clap_enum();
let from_subcommand = gen_from_subcommand(name, variants, &attrs);
let paw_impl = gen_paw_impl(name);
let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause);

quote! {
#[allow(unknown_lints)]
Expand All @@ -834,7 +902,7 @@ fn impl_structopt_for_enum(
clippy::cargo
)]
#[deny(clippy::correctness)]
impl ::structopt::StructOpt for #name {
impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause {
#clap_tokens
#from_clap
}
Expand All @@ -853,7 +921,7 @@ fn impl_structopt_for_enum(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOptInternal for #name {
impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause {
#augment_clap
#from_subcommand
fn is_subcommand() -> bool { true }
Expand Down Expand Up @@ -885,8 +953,8 @@ fn impl_structopt(input: &DeriveInput) -> TokenStream {
Struct(DataStruct {
fields: syn::Fields::Named(ref fields),
..
}) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs),
Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs),
}) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs, &input.generics),
Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs, &input.generics),
_ => abort_call_site!("structopt only supports non-tuple structs and enums"),
}
}
86 changes: 86 additions & 0 deletions tests/generics.rs
@@ -0,0 +1,86 @@

use structopt::StructOpt;

#[test]
fn generic_struct_flatten() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
struct Outer<T:StructOpt>{
#[structopt(flatten)]
pub inner: T
}

assert_eq!(
Outer{inner: Inner{ answer: 42 }},
Outer::from_iter(&[ "--answer", "42" ])
)
}

#[test]
fn generic_struct_flatten_w_where_clause() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
struct Outer<T> where T:StructOpt {
#[structopt(flatten)]
pub inner: T
}

assert_eq!(
Outer{inner: Inner{ answer: 42 }},
Outer::from_iter(&[ "--answer", "42" ])
)
}

#[test]
fn generic_enum() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
enum GenericEnum<T: StructOpt> {

Start(T),
Stop,
}

assert_eq!(
GenericEnum::Start(Inner{answer: 42}),
GenericEnum::from_iter(&[ "test", "start", "42" ])
)

}

#[test]
fn generic_enum_w_where_clause() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
enum GenericEnum<T> where T: StructOpt {

Start(T),
Stop,
}

assert_eq!(
GenericEnum::Start(Inner{answer: 42}),
GenericEnum::from_iter(&[ "test", "start", "42" ])
)

}