diff --git a/clap_derive/src/derives/args.rs b/clap_derive/src/derives/args.rs index 21786505a82..f49883ce3bb 100644 --- a/clap_derive/src/derives/args.rs +++ b/clap_derive/src/derives/args.rs @@ -202,8 +202,12 @@ pub fn gen_augment( #implicit_methods; }) } - Kind::Flatten => { - let ty = &field.ty; + Kind::Flatten(ty) => { + let inner_type = match (**ty, sub_type(&field.ty)) { + (Ty::Option, Some(sub_type)) => sub_type, + _ => &field.ty, + }; + let next_help_heading = item.next_help_heading(); let next_display_order = item.next_display_order(); if override_required { @@ -211,14 +215,14 @@ pub fn gen_augment( let #app_var = #app_var #next_help_heading #next_display_order; - let #app_var = <#ty as clap::Args>::augment_args_for_update(#app_var); + let #app_var = <#inner_type as clap::Args>::augment_args_for_update(#app_var); }) } else { Some(quote_spanned! { kind.span()=> let #app_var = #app_var #next_help_heading #next_display_order; - let #app_var = <#ty as clap::Args>::augment_args(#app_var); + let #app_var = <#inner_type as clap::Args>::augment_args(#app_var); }) } } @@ -350,7 +354,7 @@ pub fn gen_augment( .iter() .filter(|(_field, item)| { let kind = item.kind(); - matches!(*kind, Kind::Flatten) + matches!(*kind, Kind::Flatten(_)) }) .count(); if 0 < possible_group_members_len { @@ -410,7 +414,6 @@ pub fn gen_constructor(fields: &[(&Field, Item)]) -> TokenStream { } } }, - Ty::Vec | Ty::Other => { quote_spanned! { kind.span()=> #field_name: { @@ -418,6 +421,7 @@ pub fn gen_constructor(fields: &[(&Field, Item)]) -> TokenStream { } } }, + Ty::Vec | Ty::OptionOption | Ty::OptionVec => { abort!( @@ -429,8 +433,42 @@ pub fn gen_constructor(fields: &[(&Field, Item)]) -> TokenStream { } } - Kind::Flatten => quote_spanned! { kind.span()=> - #field_name: clap::FromArgMatches::from_arg_matches_mut(#arg_matches)? + Kind::Flatten(ty) => { + let inner_type = match (**ty, sub_type(&field.ty)) { + (Ty::Option, Some(sub_type)) => sub_type, + _ => &field.ty, + }; + match **ty { + Ty::Other => { + quote_spanned! { kind.span()=> + #field_name: <#inner_type as clap::FromArgMatches>::from_arg_matches_mut(#arg_matches)? + } + }, + Ty::Option => { + quote_spanned! { kind.span()=> + #field_name: { + let group_id = <#inner_type as clap::Args>::group_id() + .expect("`#[arg(flatten)]`ed field type implements `Args::group_id`"); + if #arg_matches.contains_id(group_id.as_str()) { + Some( + <#inner_type as clap::FromArgMatches>::from_arg_matches_mut(#arg_matches)? + ) + } else { + None + } + } + } + }, + Ty::Vec | + Ty::OptionOption | + Ty::OptionVec => { + abort!( + ty.span(), + "{} types are not supported for flatten", + ty.as_str() + ); + } + } }, Kind::Skip(val, _) => match val { @@ -506,9 +544,36 @@ pub fn gen_updater(fields: &[(&Field, Item)], use_self: bool) -> TokenStream { } } - Kind::Flatten => quote_spanned! { kind.span()=> { - #access - clap::FromArgMatches::update_from_arg_matches_mut(#field_name, #arg_matches)?; + Kind::Flatten(ty) => { + let inner_type = match (**ty, sub_type(&field.ty)) { + (Ty::Option, Some(sub_type)) => sub_type, + _ => &field.ty, + }; + + let updater = quote_spanned! { ty.span()=> + <#inner_type as clap::FromArgMatches>::update_from_arg_matches_mut(#field_name, #arg_matches)?; + }; + + let updater = match **ty { + Ty::Option => quote_spanned! { kind.span()=> + if let Some(#field_name) = #field_name.as_mut() { + #updater + } else { + *#field_name = Some(<#inner_type as clap::FromArgMatches>::from_arg_matches_mut( + #arg_matches + )?); + } + }, + _ => quote_spanned! { kind.span()=> + #updater + }, + }; + + quote_spanned! { kind.span()=> + { + #access + #updater + } } }, diff --git a/clap_derive/src/derives/subcommand.rs b/clap_derive/src/derives/subcommand.rs index 1989e82fded..b437e42c076 100644 --- a/clap_derive/src/derives/subcommand.rs +++ b/clap_derive/src/derives/subcommand.rs @@ -173,7 +173,7 @@ fn gen_augment( Some(subcommand) } - Kind::Flatten => match variant.fields { + Kind::Flatten(_) => match variant.fields { Unnamed(FieldsUnnamed { ref unnamed, .. }) if unnamed.len() == 1 => { let ty = &unnamed[0]; let deprecations = if !override_required { @@ -363,7 +363,7 @@ fn gen_has_subcommand(variants: &[(&Variant, Item)]) -> TokenStream { }) .partition(|(_, item)| { let kind = item.kind(); - matches!(&*kind, Kind::Flatten) + matches!(&*kind, Kind::Flatten(_)) }); let subcommands = variants.iter().map(|(_variant, item)| { @@ -464,7 +464,7 @@ fn gen_from_arg_matches(variants: &[(&Variant, Item)]) -> TokenStream { }) .partition(|(_, item)| { let kind = item.kind(); - matches!(&*kind, Kind::Flatten) + matches!(&*kind, Kind::Flatten(_)) }); let subcommands = variants.iter().map(|(variant, item)| { @@ -571,7 +571,7 @@ fn gen_update_from_arg_matches(variants: &[(&Variant, Item)]) -> TokenStream { }) .partition(|(_, item)| { let kind = item.kind(); - matches!(&*kind, Kind::Flatten) + matches!(&*kind, Kind::Flatten(_)) }); let subcommands = variants.iter().map(|(variant, item)| { diff --git a/clap_derive/src/item.rs b/clap_derive/src/item.rs index 68af9d62836..b9f64da6499 100644 --- a/clap_derive/src/item.rs +++ b/clap_derive/src/item.rs @@ -148,7 +148,7 @@ impl Item { } match &*res.kind { - Kind::Flatten => { + Kind::Flatten(_) => { if res.has_explicit_methods() { abort!( res.kind.span(), @@ -224,7 +224,7 @@ impl Item { } match &*res.kind { - Kind::Flatten => { + Kind::Flatten(_) => { if res.has_explicit_methods() { abort!( res.kind.span(), @@ -383,7 +383,12 @@ impl Item { let expr = attr.value_or_abort(); abort!(expr, "attribute `{}` does not accept a value", attr.name); } - let kind = Sp::new(Kind::Flatten, attr.name.clone().span()); + let ty = self + .kind() + .ty() + .cloned() + .unwrap_or_else(|| Sp::new(Ty::Other, self.kind.span())); + let kind = Sp::new(Kind::Flatten(ty), attr.name.clone().span()); Some(kind) } Some(MagicAttrName::Skip) if actual_attr_kind != AttrKind::Group => { @@ -902,10 +907,10 @@ impl Item { match (self.kind.get(), kind.get()) { (Kind::Arg(_), Kind::FromGlobal(_)) | (Kind::Arg(_), Kind::Subcommand(_)) - | (Kind::Arg(_), Kind::Flatten) + | (Kind::Arg(_), Kind::Flatten(_)) | (Kind::Arg(_), Kind::Skip(_, _)) | (Kind::Command(_), Kind::Subcommand(_)) - | (Kind::Command(_), Kind::Flatten) + | (Kind::Command(_), Kind::Flatten(_)) | (Kind::Command(_), Kind::Skip(_, _)) | (Kind::Command(_), Kind::ExternalSubcommand) | (Kind::Value, Kind::Skip(_, _)) => { @@ -1142,7 +1147,7 @@ pub enum Kind { Value, FromGlobal(Sp), Subcommand(Sp), - Flatten, + Flatten(Sp), Skip(Option, AttrKind), ExternalSubcommand, } @@ -1155,7 +1160,7 @@ impl Kind { Self::Value => "value", Self::FromGlobal(_) => "from_global", Self::Subcommand(_) => "subcommand", - Self::Flatten => "flatten", + Self::Flatten(_) => "flatten", Self::Skip(_, _) => "skip", Self::ExternalSubcommand => "external_subcommand", } @@ -1168,7 +1173,7 @@ impl Kind { Self::Value => AttrKind::Value, Self::FromGlobal(_) => AttrKind::Arg, Self::Subcommand(_) => AttrKind::Command, - Self::Flatten => AttrKind::Command, + Self::Flatten(_) => AttrKind::Command, Self::Skip(_, kind) => *kind, Self::ExternalSubcommand => AttrKind::Command, } @@ -1176,10 +1181,12 @@ impl Kind { pub fn ty(&self) -> Option<&Sp> { match self { - Self::Arg(ty) | Self::Command(ty) | Self::FromGlobal(ty) | Self::Subcommand(ty) => { - Some(ty) - } - Self::Value | Self::Flatten | Self::Skip(_, _) | Self::ExternalSubcommand => None, + Self::Arg(ty) + | Self::Command(ty) + | Self::Flatten(ty) + | Self::FromGlobal(ty) + | Self::Subcommand(ty) => Some(ty), + Self::Value | Self::Skip(_, _) | Self::ExternalSubcommand => None, } } } diff --git a/tests/derive/flatten.rs b/tests/derive/flatten.rs index d4330f02543..7ef6ad86cd8 100644 --- a/tests/derive/flatten.rs +++ b/tests/derive/flatten.rs @@ -255,3 +255,46 @@ fn docstrings_ordering_with_multiple_clap_partial() { assert!(short_help.contains("This is the docstring for Flattened")); } + +#[test] +fn optional_flatten() { + #[derive(Parser, Debug, PartialEq, Eq)] + struct Opt { + #[command(flatten)] + source: Option, + } + + #[derive(clap::Args, Debug, PartialEq, Eq)] + struct Source { + crates: Vec, + #[arg(long)] + path: Option, + #[arg(long)] + git: Option, + } + + assert_eq!( + Opt { source: None }, + Opt::try_parse_from(&["test"]).unwrap() + ); + assert_eq!( + Opt { + source: Some(Source { + crates: vec!["serde".to_owned()], + path: None, + git: None, + }), + }, + Opt::try_parse_from(&["test", "serde"]).unwrap() + ); + assert_eq!( + Opt { + source: Some(Source { + crates: Vec::new(), + path: Some("./".into()), + git: None, + }), + }, + Opt::try_parse_from(&["test", "--path=./"]).unwrap() + ); +}