Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic support for generic structs and enums #483

Merged
merged 9 commits into from Jul 1, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,7 @@
# unreleased

* Add support for [generics in derive](https://github.com/TeXitoi/structopt/issues/128)

# v0.3.21 (2020-11-30)

* Fixed [another breakage](https://github.com/TeXitoi/structopt/issues/447)
Expand Down
35 changes: 35 additions & 0 deletions src/lib.rs
Expand Up @@ -52,6 +52,7 @@
//! - [Flattening subcommands](#flattening-subcommands)
//! - [Flattening](#flattening)
//! - [Custom string parsers](#custom-string-parsers)
//! - [Generics](#generics)
//!
//!
//!
Expand Down Expand Up @@ -1053,6 +1054,40 @@
//! In the `try_from_*` variants, the function will run twice on valid input:
//! once to validate, and once to parse. Hence, make sure the function is
//! side-effect-free.
//!
//! ## Generics
//!
//! Generic structs and enums can be used. They require explicit trait bounds
//! any generic types that will be used by the `StructOpt` derive macro.
//!
//! ```
//! # use structopt::StructOpt;
//! use std::str::FromStr;
//!
//! // a struct with single custom argument
//! #[derive(StructOpt)]
//! struct GenericArgs<T:FromStr> {
//! generic_arg_1: String,
//! generic_arg_2: String,
//! custom_arg_1: T
//! }
//! ```
//!
//! or
//!
//! ```
//! # use structopt::StructOpt;
//! // a struct with multiple custom arguments in a substructure
//! #[derive(StructOpt)]
//! struct GenericArgs<T:StructOpt> {
//! generic_arg_1: String,
//! generic_arg_2: String,
//! #[structopt(flatten)]
//! custom_args: T
//! }
//! ```



// those mains are for a reason
#![allow(clippy::needless_doctest_main)]
Expand Down
128 changes: 114 additions & 14 deletions structopt-derive/src/lib.rs
Expand Up @@ -559,10 +559,10 @@ fn gen_augment_clap_enum(
}
}

fn gen_from_clap_enum(name: &Ident) -> TokenStream {
fn gen_from_clap_enum() -> TokenStream {
quote! {
fn from_clap(matches: &::structopt::clap::ArgMatches) -> Self {
<#name as ::structopt::StructOptInternal>::from_subcommand(matches.subcommand())
<Self as ::structopt::StructOptInternal>::from_subcommand(matches.subcommand())
.expect("structopt misuse: You likely tried to #[flatten] a struct \
that contains #[subcommand]. This is forbidden.")
}
Expand Down Expand Up @@ -736,9 +736,9 @@ fn gen_from_subcommand(
}

#[cfg(feature = "paw")]
fn gen_paw_impl(name: &Ident) -> TokenStream {
fn gen_paw_impl(impl_generics: &ImplGenerics, name: &Ident, ty_generics: &TypeGenerics, where_clause: &TokenStream) -> TokenStream {
quote! {
impl ::structopt::paw::ParseArgs for #name {
impl #impl_generics ::structopt::paw::ParseArgs for #name #ty_generics #where_clause {
type Error = std::io::Error;

fn parse_args() -> std::result::Result<Self, Self::Error> {
Expand All @@ -748,19 +748,115 @@ fn gen_paw_impl(name: &Ident) -> TokenStream {
}
}
#[cfg(not(feature = "paw"))]
fn gen_paw_impl(_: &Ident) -> TokenStream {
fn gen_paw_impl(_: &ImplGenerics, _: &Ident, _: &TypeGenerics, _: &TokenStream) -> TokenStream {
TokenStream::new()
}

fn split_structopt_generics_for_impl(generics: &Generics) -> (ImplGenerics, TypeGenerics, TokenStream) {
use syn::{ token::Add, TypeParamBound::Trait };

fn path_ends_with(path: &Path, ident: &str) -> bool {
path.segments.last().unwrap().ident == ident
}

fn type_param_bounds_contains(bounds: &Punctuated<TypeParamBound, Add>, ident: &str) -> bool {
for bound in bounds {
if let Trait(bound) = bound {
if path_ends_with(&bound.path, ident) {
return true;
}
}
}
return false;
}

struct TraitBoundAmendments{
tokens: TokenStream,
need_where: bool,
need_comma: bool,
}

impl TraitBoundAmendments {
fn new(where_clause: Option<&WhereClause>) -> Self {
let tokens = TokenStream::new();
let (need_where,need_comma) = if let Some(where_clause) = where_clause {
if where_clause.predicates.trailing_punct() {
(false, false)
} else {
(false, true)
}
} else {
(true, false)
};
Self{tokens, need_where, need_comma}
}

fn add(&mut self, amendment: TokenStream) {
if self.need_where {
self.tokens.extend(quote!{ where });
self.need_where = false;
}
if self.need_comma {
self.tokens.extend(quote!{ , });
}
self.tokens.extend(amendment);
self.need_comma = true;
}

fn into_tokens(self) -> TokenStream {
self.tokens
}
}

let mut trait_bound_amendments = TraitBoundAmendments::new(generics.where_clause.as_ref());

for param in &generics.params {
if let GenericParam::Type(param) = param {
let param_ident = &param.ident;
if type_param_bounds_contains(&param.bounds, "StructOpt") {
trait_bound_amendments.add(quote!{ #param_ident : ::structopt::StructOptInternal });
}
if type_param_bounds_contains(&param.bounds, "FromStr") {
trait_bound_amendments.add(quote!{ < #param_ident as ::std::str::FromStr>::Err : ::std::fmt::Display + ::std::fmt::Debug });
}
}
}

if let Some(where_clause) = &generics.where_clause {
for predicate in &where_clause.predicates {
if let WherePredicate::Type(predicate) = predicate {
let predicate_bounded_ty = &predicate.bounded_ty;
if type_param_bounds_contains(&predicate.bounds, "StructOpt") {
trait_bound_amendments.add(quote!{ #predicate_bounded_ty : ::structopt::StructOptInternal });
}
if type_param_bounds_contains(&predicate.bounds, "FromStr") {
trait_bound_amendments.add(quote!{ < #predicate_bounded_ty as ::std::str::FromStr>::Err : ::std::fmt::Display + ::std::fmt::Debug });
njeffords marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}

let trait_bound_amendments = trait_bound_amendments.into_tokens();

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let where_clause = quote!{ #where_clause #trait_bound_amendments };

(impl_generics, ty_generics, where_clause)
}

fn impl_structopt_for_struct(
name: &Ident,
fields: &Punctuated<Field, Comma>,
attrs: &[Attribute],
generics: &Generics,
) -> TokenStream {
let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics);

let basic_clap_app_gen = gen_clap_struct(attrs);
let augment_clap = gen_augment_clap(fields, &basic_clap_app_gen.attrs);
let from_clap = gen_from_clap(name, fields, &basic_clap_app_gen.attrs);
let paw_impl = gen_paw_impl(name);
let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause);

let clap_tokens = basic_clap_app_gen.tokens;
quote! {
Expand All @@ -778,7 +874,7 @@ fn impl_structopt_for_struct(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOpt for #name {
impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause {
#clap_tokens
#from_clap
}
Expand All @@ -797,7 +893,7 @@ fn impl_structopt_for_struct(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOptInternal for #name {
impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause {
#augment_clap
fn is_subcommand() -> bool { false }
}
Expand All @@ -810,15 +906,19 @@ fn impl_structopt_for_enum(
name: &Ident,
variants: &Punctuated<Variant, Comma>,
attrs: &[Attribute],
generics: &Generics,
) -> TokenStream {

let (impl_generics, ty_generics, where_clause) = split_structopt_generics_for_impl(&generics);

let basic_clap_app_gen = gen_clap_enum(attrs);
let clap_tokens = basic_clap_app_gen.tokens;
let attrs = basic_clap_app_gen.attrs;

let augment_clap = gen_augment_clap_enum(variants, &attrs);
let from_clap = gen_from_clap_enum(name);
let from_clap = gen_from_clap_enum();
let from_subcommand = gen_from_subcommand(name, variants, &attrs);
let paw_impl = gen_paw_impl(name);
let paw_impl = gen_paw_impl(&impl_generics, name, &ty_generics, &where_clause);

quote! {
#[allow(unknown_lints)]
Expand All @@ -834,7 +934,7 @@ fn impl_structopt_for_enum(
clippy::cargo
)]
#[deny(clippy::correctness)]
impl ::structopt::StructOpt for #name {
impl #impl_generics ::structopt::StructOpt for #name #ty_generics #where_clause {
#clap_tokens
#from_clap
}
Expand All @@ -853,7 +953,7 @@ fn impl_structopt_for_enum(
)]
#[deny(clippy::correctness)]
#[allow(dead_code, unreachable_code)]
impl ::structopt::StructOptInternal for #name {
impl #impl_generics ::structopt::StructOptInternal for #name #ty_generics #where_clause {
#augment_clap
#from_subcommand
fn is_subcommand() -> bool { true }
Expand Down Expand Up @@ -885,8 +985,8 @@ fn impl_structopt(input: &DeriveInput) -> TokenStream {
Struct(DataStruct {
fields: syn::Fields::Named(ref fields),
..
}) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs),
Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs),
}) => impl_structopt_for_struct(struct_name, &fields.named, &input.attrs, &input.generics),
Enum(ref e) => impl_structopt_for_enum(struct_name, &e.variants, &input.attrs, &input.generics),
_ => abort_call_site!("structopt only supports non-tuple structs and enums"),
}
}