diff --git a/CHANGELOG.md b/CHANGELOG.md index 650afea..4734c9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# unreleased + +* Add support for [generics in derive](https://github.com/TeXitoi/structopt/issues/128) + # v0.3.21 (2020-11-30) * Fixed [another breakage](https://github.com/TeXitoi/structopt/issues/447) diff --git a/src/lib.rs b/src/lib.rs index fb4ad85..b6e6b73 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,7 @@ //! - [Flattening subcommands](#flattening-subcommands) //! - [Flattening](#flattening) //! - [Custom string parsers](#custom-string-parsers) +//! - [Generics](#generics) //! //! //! @@ -1053,6 +1054,42 @@ //! In the `try_from_*` variants, the function will run twice on valid input: //! once to validate, and once to parse. Hence, make sure the function is //! side-effect-free. +//! +//! ## Generics +//! +//! Generic structs and enums can be used. They require explicit trait bounds +//! on any generic types that will be used by the `StructOpt` derive macro. In +//! some cases, associated types will require additional bounds. See the usage +//! of `FromStr` below for an example of this. +//! +//! ``` +//! # use structopt::StructOpt; +//! use std::{fmt, str::FromStr}; +//! +//! // a struct with single custom argument +//! #[derive(StructOpt)] +//! struct GenericArgs where ::Err: fmt::Display + fmt::Debug { +//! generic_arg_1: String, +//! generic_arg_2: String, +//! custom_arg_1: T +//! } +//! ``` +//! +//! or +//! +//! ``` +//! # use structopt::StructOpt; +//! // a struct with multiple custom arguments in a substructure +//! #[derive(StructOpt)] +//! struct GenericArgs { +//! generic_arg_1: String, +//! generic_arg_2: String, +//! #[structopt(flatten)] +//! custom_args: T +//! } +//! ``` + + // those mains are for a reason #![allow(clippy::needless_doctest_main)] diff --git a/structopt-derive/src/lib.rs b/structopt-derive/src/lib.rs index b818386..cf4dbba 100644 --- a/structopt-derive/src/lib.rs +++ b/structopt-derive/src/lib.rs @@ -559,10 +559,10 @@ fn gen_augment_clap_enum( } } -fn gen_from_clap_enum(name: &Ident) -> TokenStream { +fn gen_from_clap_enum() -> TokenStream { quote! { fn from_clap(matches: &::structopt::clap::ArgMatches) -> Self { - <#name as ::structopt::StructOptInternal>::from_subcommand(matches.subcommand()) + ::from_subcommand(matches.subcommand()) .expect("structopt misuse: You likely tried to #[flatten] a struct \ that contains #[subcommand]. This is forbidden.") } @@ -736,9 +736,9 @@ fn gen_from_subcommand( } #[cfg(feature = "paw")] -fn gen_paw_impl(name: &Ident) -> TokenStream { +fn gen_paw_impl(impl_generics: &ImplGenerics, name: &Ident, ty_generics: &TypeGenerics, where_clause: &TokenStream) -> TokenStream { quote! { - impl ::structopt::paw::ParseArgs for #name { + impl #impl_generics ::structopt::paw::ParseArgs for #name #ty_generics #where_clause { type Error = std::io::Error; fn parse_args() -> std::result::Result { @@ -748,19 +748,109 @@ fn gen_paw_impl(name: &Ident) -> TokenStream { } } #[cfg(not(feature = "paw"))] -fn gen_paw_impl(_: &Ident) -> TokenStream { +fn gen_paw_impl(_: &ImplGenerics, _: &Ident, _: &TypeGenerics, _: &TokenStream) -> TokenStream { TokenStream::new() } +fn split_structopt_generics_for_impl(generics: &Generics) -> (ImplGenerics, TypeGenerics, TokenStream) { + use syn::{ token::Add, TypeParamBound::Trait }; + + fn path_ends_with(path: &Path, ident: &str) -> bool { + path.segments.last().unwrap().ident == ident + } + + fn type_param_bounds_contains(bounds: &Punctuated, ident: &str) -> bool { + for bound in bounds { + if let Trait(bound) = bound { + if path_ends_with(&bound.path, ident) { + return true; + } + } + } + return false; + } + + struct TraitBoundAmendments{ + tokens: TokenStream, + need_where: bool, + need_comma: bool, + } + + impl TraitBoundAmendments { + fn new(where_clause: Option<&WhereClause>) -> Self { + let tokens = TokenStream::new(); + let (need_where,need_comma) = if let Some(where_clause) = where_clause { + if where_clause.predicates.trailing_punct() { + (false, false) + } else { + (false, true) + } + } else { + (true, false) + }; + Self{tokens, need_where, need_comma} + } + + fn add(&mut self, amendment: TokenStream) { + if self.need_where { + self.tokens.extend(quote!{ where }); + self.need_where = false; + } + if self.need_comma { + self.tokens.extend(quote!{ , }); + } + self.tokens.extend(amendment); + self.need_comma = true; + } + + fn into_tokens(self) -> TokenStream { + self.tokens + } + } + + let mut trait_bound_amendments = TraitBoundAmendments::new(generics.where_clause.as_ref()); + + for param in &generics.params { + if let GenericParam::Type(param) = param { + let param_ident = ¶m.ident; + if type_param_bounds_contains(¶m.bounds, "StructOpt") { + trait_bound_amendments.add(quote!{ #param_ident : ::structopt::StructOptInternal }); + } + } + } + + if let Some(where_clause) = &generics.where_clause { + for predicate in &where_clause.predicates { + if let WherePredicate::Type(predicate) = predicate { + let predicate_bounded_ty = &predicate.bounded_ty; + if type_param_bounds_contains(&predicate.bounds, "StructOpt") { + trait_bound_amendments.add(quote!{ #predicate_bounded_ty : ::structopt::StructOptInternal }); + } + } + } + } + + let trait_bound_amendments = trait_bound_amendments.into_tokens(); + + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let where_clause = quote!{ #where_clause #trait_bound_amendments }; + + (impl_generics, ty_generics, where_clause) +} + fn impl_structopt_for_struct( name: &Ident, fields: &Punctuated, attrs: &[Attribute], + generics: &Generics, ) -> TokenStream { + let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics); + let basic_clap_app_gen = gen_clap_struct(attrs); let augment_clap = gen_augment_clap(fields, &basic_clap_app_gen.attrs); let from_clap = gen_from_clap(name, fields, &basic_clap_app_gen.attrs); - let paw_impl = gen_paw_impl(name); + let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause); let clap_tokens = basic_clap_app_gen.tokens; quote! { @@ -778,7 +868,7 @@ fn impl_structopt_for_struct( )] #[deny(clippy::correctness)] #[allow(dead_code, unreachable_code)] - impl ::structopt::StructOpt for #name { + impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause { #clap_tokens #from_clap } @@ -797,7 +887,7 @@ fn impl_structopt_for_struct( )] #[deny(clippy::correctness)] #[allow(dead_code, unreachable_code)] - impl ::structopt::StructOptInternal for #name { + impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause { #augment_clap fn is_subcommand() -> bool { false } } @@ -810,15 +900,19 @@ fn impl_structopt_for_enum( name: &Ident, variants: &Punctuated, attrs: &[Attribute], + generics: &Generics, ) -> TokenStream { + + let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics); + let basic_clap_app_gen = gen_clap_enum(attrs); let clap_tokens = basic_clap_app_gen.tokens; let attrs = basic_clap_app_gen.attrs; let augment_clap = gen_augment_clap_enum(variants, &attrs); - let from_clap = gen_from_clap_enum(name); + let from_clap = gen_from_clap_enum(); let from_subcommand = gen_from_subcommand(name, variants, &attrs); - let paw_impl = gen_paw_impl(name); + let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause); quote! { #[allow(unknown_lints)] @@ -834,7 +928,7 @@ fn impl_structopt_for_enum( clippy::cargo )] #[deny(clippy::correctness)] - impl ::structopt::StructOpt for #name { + impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause { #clap_tokens #from_clap } @@ -853,7 +947,7 @@ fn impl_structopt_for_enum( )] #[deny(clippy::correctness)] #[allow(dead_code, unreachable_code)] - impl ::structopt::StructOptInternal for #name { + impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause { #augment_clap #from_subcommand fn is_subcommand() -> bool { true } @@ -885,8 +979,8 @@ fn impl_structopt(input: &DeriveInput) -> TokenStream { Struct(DataStruct { fields: syn::Fields::Named(ref fields), .. - }) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs), - Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs), + }) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs, &input.generics), + Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs, &input.generics), _ => abort_call_site!("structopt only supports non-tuple structs and enums"), } } diff --git a/tests/generics.rs b/tests/generics.rs new file mode 100644 index 0000000..896f98a --- /dev/null +++ b/tests/generics.rs @@ -0,0 +1,137 @@ + +use structopt::StructOpt; + +#[test] +fn generic_struct_flatten() { + + #[derive(StructOpt,PartialEq,Debug)] + struct Inner{ + pub answer: isize + } + + #[derive(StructOpt,PartialEq,Debug)] + struct Outer{ + #[structopt(flatten)] + pub inner: T + } + + assert_eq!( + Outer{inner: Inner{ answer: 42 }}, + Outer::from_iter(&[ "--answer", "42" ]) + ) +} + +#[test] +fn generic_struct_flatten_w_where_clause() { + + #[derive(StructOpt,PartialEq,Debug)] + struct Inner{ + pub answer: isize + } + + #[derive(StructOpt,PartialEq,Debug)] + struct Outer where T:StructOpt { + #[structopt(flatten)] + pub inner: T + } + + assert_eq!( + Outer{inner: Inner{ answer: 42 }}, + Outer::from_iter(&[ "--answer", "42" ]) + ) +} + +#[test] +fn generic_enum() { + + #[derive(StructOpt,PartialEq,Debug)] + struct Inner{ + pub answer: isize + } + + #[derive(StructOpt,PartialEq,Debug)] + enum GenericEnum { + + Start(T), + Stop, + } + + assert_eq!( + GenericEnum::Start(Inner{answer: 42}), + GenericEnum::from_iter(&[ "test", "start", "42" ]) + ) + +} + +#[test] +fn generic_enum_w_where_clause() { + + #[derive(StructOpt,PartialEq,Debug)] + struct Inner{ + pub answer: isize + } + + #[derive(StructOpt,PartialEq,Debug)] + enum GenericEnum where T: StructOpt { + + Start(T), + Stop, + } + + assert_eq!( + GenericEnum::Start(Inner{answer: 42}), + GenericEnum::from_iter(&[ "test", "start", "42" ]) + ) + +} + +#[test] +fn generic_w_fromstr_trait_bound() { + + use std::{fmt, str::FromStr}; + + #[derive(StructOpt,PartialEq,Debug)] + struct Opt where T:FromStr, ::Err: fmt::Debug + fmt::Display + { + answer: T + } + + assert_eq!( + Opt::{answer:42}, + Opt::::from_iter([& "--answer", "42" ]) + ) +} + +#[test] +fn generic_wo_trait_bound() { + + use std::time::Duration; + + #[derive(StructOpt,PartialEq,Debug)] + struct Opt { + answer: isize, + #[structopt(skip)] + took: Option + } + + assert_eq!( + Opt::{answer:42,took:None}, + Opt::::from_iter([& "--answer", "42" ]) + ) +} + +#[test] +fn generic_where_clause_w_trailing_comma() { + + use std::{fmt, str::FromStr}; + + #[derive(StructOpt,PartialEq,Debug)] + struct Opt where T:FromStr, ::Err: fmt::Debug + fmt::Display { + pub answer: T + } + + assert_eq!( + Opt::{answer:42}, + Opt::::from_iter(&[ "--answer", "42" ]) + ) +}