Skip to content

Commit

Permalink
propagate generics when deriving StructOpt
Browse files Browse the repository at this point in the history
  • Loading branch information
njeffords committed Jun 24, 2021
1 parent 2ba552a commit c174294
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 11 deletions.
90 changes: 79 additions & 11 deletions structopt-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,9 +736,9 @@ fn gen_from_subcommand(
}

#[cfg(feature = "paw")]
fn gen_paw_impl(name: &Ident) -> TokenStream {
fn gen_paw_impl(impl_generics: &ImplGenerics, name: &Ident, ty_generics: &TypeGenerics, where_clause: &TokenStream) -> TokenStream {
quote! {
impl ::structopt::paw::ParseArgs for #name {
impl #impl_generics ::structopt::paw::ParseArgs for #name #ty_generics #where_clause {
type Error = std::io::Error;

fn parse_args() -> std::result::Result<Self, Self::Error> {
Expand All @@ -748,19 +748,83 @@ fn gen_paw_impl(name: &Ident) -> TokenStream {
}
}
#[cfg(not(feature = "paw"))]
fn gen_paw_impl(_: &Ident) -> TokenStream {
fn gen_paw_impl(_: &ImplGenerics, _: &Ident, _: &TypeGenerics, _: &TokenStream) -> TokenStream {
TokenStream::new()
}

fn split_structopt_generics_for_impl(generics: &Generics) -> (ImplGenerics, TypeGenerics, TokenStream) {
use syn::{ token::Add, TypeParamBound::Trait };

fn path_is_structop(path: &Path) -> bool {
path.segments.last().unwrap().ident == "StructOpt"
}

fn type_param_bounds_contains_structop(bounds: &Punctuated<TypeParamBound, Add>) -> bool {
for bound in bounds {
if let Trait(bound) = bound {
if path_is_structop(&bound.path) {
return true;
}
}
}
return false;
}

let mut trait_bound_amendments = TokenStream::new();

for param in &generics.params {
if let GenericParam::Type(param) = param {
let param_ident = &param.ident;
if type_param_bounds_contains_structop(&param.bounds) {
if !trait_bound_amendments.is_empty() {
trait_bound_amendments.extend(quote!{ , });
}
trait_bound_amendments.extend(quote!{ #param_ident : ::structopt::StructOptInternal });
}
}
}

if let Some(where_clause) = &generics.where_clause {
for predicate in &where_clause.predicates {
if let WherePredicate::Type(predicate) = predicate {
let predicate_bounded_ty = &predicate.bounded_ty;
if type_param_bounds_contains_structop(&predicate.bounds) {
if !trait_bound_amendments.is_empty() {
trait_bound_amendments.extend(quote!{ , });
}
trait_bound_amendments.extend(quote!{ #predicate_bounded_ty : ::structopt::StructOptInternal });
}
}
}
}

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let where_clause = if !trait_bound_amendments.is_empty() {
if let Some(where_clause) = where_clause {
quote!{ #where_clause, #trait_bound_amendments }
} else {
quote!{ where #trait_bound_amendments }
}
} else {
quote!{ #where_clause }
};

(impl_generics, ty_generics, where_clause)
}

fn impl_structopt_for_struct(
name: &Ident,
fields: &Punctuated<Field, Comma>,
attrs: &[Attribute],
generics: &Generics,
) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics);

let basic_clap_app_gen = gen_clap_struct(attrs);
let augment_clap = gen_augment_clap(fields, &basic_clap_app_gen.attrs);
let from_clap = gen_from_clap(name, fields, &basic_clap_app_gen.attrs);
let paw_impl = gen_paw_impl(name);
let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause);

let clap_tokens = basic_clap_app_gen.tokens;
quote! {
Expand All @@ -778,7 +842,7 @@ fn impl_structopt_for_struct(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOpt for #name {
impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause {
#clap_tokens
#from_clap
}
Expand All @@ -797,7 +861,7 @@ fn impl_structopt_for_struct(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOptInternal for #name {
impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause {
#augment_clap
fn is_subcommand() -> bool { false }
}
Expand All @@ -810,15 +874,19 @@ fn impl_structopt_for_enum(
name: &Ident,
variants: &Punctuated<Variant, Comma>,
attrs: &[Attribute],
generics: &Generics,
) -> TokenStream {

let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics);

let basic_clap_app_gen = gen_clap_enum(attrs);
let clap_tokens = basic_clap_app_gen.tokens;
let attrs = basic_clap_app_gen.attrs;

let augment_clap = gen_augment_clap_enum(variants, &attrs);
let from_clap = gen_from_clap_enum();
let from_subcommand = gen_from_subcommand(name, variants, &attrs);
let paw_impl = gen_paw_impl(name);
let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause);

quote! {
#[allow(unknown_lints)]
Expand All @@ -834,7 +902,7 @@ fn impl_structopt_for_enum(
clippy::cargo
)]
#[deny(clippy::correctness)]
impl ::structopt::StructOpt for #name {
impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause {
#clap_tokens
#from_clap
}
Expand All @@ -853,7 +921,7 @@ fn impl_structopt_for_enum(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOptInternal for #name {
impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause {
#augment_clap
#from_subcommand
fn is_subcommand() -> bool { true }
Expand Down Expand Up @@ -885,8 +953,8 @@ fn impl_structopt(input: &DeriveInput) -> TokenStream {
Struct(DataStruct {
fields: syn::Fields::Named(ref fields),
..
}) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs),
Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs),
}) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs, &input.generics),
Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs, &input.generics),
_ => abort_call_site!("structopt only supports non-tuple structs and enums"),
}
}
86 changes: 86 additions & 0 deletions tests/generics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

use structopt::StructOpt;

#[test]
fn generic_struct_flatten() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
struct Outer<T:StructOpt>{
#[structopt(flatten)]
pub inner: T
}

assert_eq!(
Outer{inner: Inner{ answer: 42 }},
Outer::from_iter(&[ "--answer", "42" ])
)
}

#[test]
fn generic_struct_flatten_w_where_clause() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
struct Outer<T> where T:StructOpt {
#[structopt(flatten)]
pub inner: T
}

assert_eq!(
Outer{inner: Inner{ answer: 42 }},
Outer::from_iter(&[ "--answer", "42" ])
)
}

#[test]
fn generic_enum() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
enum GenericEnum<T: StructOpt> {

Start(T),
Stop,
}

assert_eq!(
GenericEnum::Start(Inner{answer: 42}),
GenericEnum::from_iter(&[ "test", "start", "42" ])
)

}

#[test]
fn generic_enum_w_where_clause() {

#[derive(StructOpt,PartialEq,Debug)]
struct Inner{
pub answer: isize
}

#[derive(StructOpt,PartialEq,Debug)]
enum GenericEnum<T> where T: StructOpt {

Start(T),
Stop,
}

assert_eq!(
GenericEnum::Start(Inner{answer: 42}),
GenericEnum::from_iter(&[ "test", "start", "42" ])
)

}

0 comments on commit c174294

Please sign in to comment.