diff --git a/strum_macros/src/macros/enum_iter.rs b/strum_macros/src/macros/enum_iter.rs index a62bfd82..4d589859 100644 --- a/strum_macros/src/macros/enum_iter.rs +++ b/strum_macros/src/macros/enum_iter.rs @@ -62,14 +62,24 @@ pub fn enum_iter_inner(ast: &syn::DeriveInput) -> TokenStream { #[allow(missing_docs)] #vis struct #iter_name #ty_generics { idx: usize, + back_idx: usize, marker: ::std::marker::PhantomData #phantom_data, } + impl #impl_generics #iter_name #ty_generics #where_clause { + fn get(&self, idx: usize) -> Option<#name #ty_generics> { + match idx { + #(#arms),* + } + } + } + impl #impl_generics ::strum::IntoEnumIterator for #name #ty_generics #where_clause { type Iterator = #iter_name #ty_generics; fn iter() -> #iter_name #ty_generics { #iter_name { - idx:0, + idx: 0, + back_idx: 0, marker: ::std::marker::PhantomData, } } @@ -78,21 +88,24 @@ pub fn enum_iter_inner(ast: &syn::DeriveInput) -> TokenStream { impl #impl_generics Iterator for #iter_name #ty_generics #where_clause { type Item = #name #ty_generics; - fn next(&mut self) -> Option<#name #ty_generics> { - let output = match self.idx { - #(#arms),* - }; - - if self.idx < #variant_count { - self.idx += 1; - } - output + fn next(&mut self) -> Option { + self.nth(0) } - + fn size_hint(&self) -> (usize, Option) { - let t = if self.idx >= #variant_count { 0 } else { #variant_count - self.idx }; + let t = if self.idx + self.back_idx >= #variant_count { 0 } else { #variant_count - self.idx - self.back_idx }; (t, Some(t)) } + + fn nth(&mut self, n: usize) -> Option { + self.idx += n + 1; + + if self.idx + self.back_idx > #variant_count { + None + } else { + self.get(self.idx - 1) + } + } } impl #impl_generics ExactSizeIterator for #iter_name #ty_generics #where_clause { @@ -101,10 +114,27 @@ pub fn enum_iter_inner(ast: &syn::DeriveInput) -> TokenStream { } } + impl #impl_generics DoubleEndedIterator for #iter_name #ty_generics #where_clause { + fn next_back(&mut self) -> Option { + self.nth_back(0) + } + + fn nth_back(&mut self, n: usize) -> Option { + self.back_idx += n + 1; + + if self.idx + self.back_idx > #variant_count { + None + } else { + self.get(#variant_count - self.back_idx) + } + } + } + impl #impl_generics Clone for #iter_name #ty_generics #where_clause { fn clone(&self) -> #iter_name #ty_generics { #iter_name { idx: self.idx, + back_idx: self.back_idx, marker: self.marker.clone(), } } diff --git a/strum_tests/tests/enum_iter.rs b/strum_tests/tests/enum_iter.rs index b0813ef9..6e1975e0 100644 --- a/strum_tests/tests/enum_iter.rs +++ b/strum_tests/tests/enum_iter.rs @@ -68,6 +68,21 @@ fn len_test() { assert_eq!(0, i.size_hint().1.unwrap()); } +#[test] +fn double_ended_len_test() { + let mut i = Complicated::<(), ()>::iter(); + assert_eq!(3, i.len()); + i.next_back(); + + assert_eq!(2, i.len()); + i.next(); + + assert_eq!(1, i.len()); + i.next_back(); + + assert_eq!(0, i.len()); +} + #[test] fn clone_test() { let mut i = Week::iter(); @@ -103,3 +118,60 @@ fn cycle_test() { ]; assert_eq!(expected, results); } + +#[test] +fn reverse_test() { + let results = Week::iter().rev().collect::>(); + let expected = vec![ + Week::Saturday, + Week::Friday, + Week::Thursday, + Week::Wednesday, + Week::Tuesday, + Week::Monday, + Week::Sunday, + ]; + assert_eq!(expected, results); +} + +#[test] +fn take_from_both_sides_test() { + let mut iter = Week::iter(); + + assert_eq!(Some(Week::Sunday), iter.next()); + assert_eq!(Some(Week::Saturday), iter.next_back()); + assert_eq!(Some(Week::Friday), iter.next_back()); + assert_eq!(Some(Week::Monday), iter.next()); + assert_eq!(Some(Week::Tuesday), iter.next()); + assert_eq!(Some(Week::Wednesday), iter.next()); + assert_eq!(Some(Week::Thursday), iter.next_back()); + assert_eq!(None, iter.next()); + assert_eq!(None, iter.next_back()); +} + +#[test] +fn take_from_both_sides_test2() { + let mut iter = Week::iter(); + + assert_eq!(Some(Week::Sunday), iter.next()); + assert_eq!(Some(Week::Saturday), iter.next_back()); + assert_eq!(Some(Week::Friday), iter.next_back()); + assert_eq!(Some(Week::Monday), iter.next()); + assert_eq!(Some(Week::Tuesday), iter.next()); + assert_eq!(Some(Week::Wednesday), iter.next()); + assert_eq!(Some(Week::Thursday), iter.next()); + assert_eq!(None, iter.next_back()); + assert_eq!(None, iter.next()); +} + +#[test] +fn take_nth_test() { + let mut iter = Week::iter(); + + assert_eq!(Some(Week::Tuesday), iter.nth(2)); + assert_eq!(Some(Week::Saturday), iter.nth_back(0)); + assert_eq!(Some(Week::Thursday), iter.nth_back(1)); + assert_eq!(None, iter.nth(1)); + assert_eq!(None, iter.next()); + assert_eq!(None, iter.next_back()); +} \ No newline at end of file