// enum_iter.rs — forked from Peternator7/strum.
use proc_macro2::TokenStream;
use syn;
use helpers::{extract_meta, is_disabled};
/// Generates the `EnumIter` derive output for the enum described by `ast`.
///
/// The expansion contains:
/// * a `<Name>Iter` struct, with the enum's visibility and type parameters, that walks the
///   enum's variants in declaration order. It skips every variant for which
///   `is_disabled(&extract_meta(&variant.attrs))` returns true;
/// * an impl of `::strum::IntoEnumIterator` for the enum that returns that struct;
/// * `Iterator`, `ExactSizeIterator`, `DoubleEndedIterator` and `Clone` impls for the struct.
///
/// Variants that carry data are produced with every field set to `Default::default()`.
///
/// The generated iterator keeps two cursors:
/// * `idx` counts the variants already yielded from the front;
/// * `back_idx` counts the variants already yielded from the back.
///
/// Iteration stops at both ends once the two cursors together cover every variant. As a
/// result, mixing `next` and `next_back` never yields a variant twice, and calling either
/// one after exhaustion neither advances a cursor nor underflows `size_hint`.
///
/// # Panics
///
/// Panics if the enum declares lifetime parameters, or if `ast` is not an enum.
pub fn enum_iter_inner(ast: &syn::DeriveInput) -> TokenStream {
    let name = &ast.ident;
    let gen = &ast.generics;
    let (impl_generics, ty_generics, where_clause) = gen.split_for_impl();
    let vis = &ast.vis;

    if gen.lifetimes().count() > 0 {
        panic!(
            "Enum Iterator isn't supported on Enums with lifetimes. The resulting enums would \
             be unbounded."
        );
    }

    // `PhantomData<(T, U, ..)>` ties the iterator to the enum's type parameters without
    // storing a value of any of them.
    let phantom_data = if gen.type_params().count() > 0 {
        let g = gen.type_params().map(|param| &param.ident);
        quote! { < ( #(#g),* ) > }
    } else {
        quote! { < () > }
    };

    let variants = match ast.data {
        syn::Data::Enum(ref v) => &v.variants,
        _ => panic!("EnumIter only works on Enums"),
    };

    // One `index => Some(Variant ..)` arm per enabled variant. Indices are assigned after
    // filtering, so they are dense in `0..variant_count`.
    let mut arms: Vec<TokenStream> = variants
        .iter()
        .filter(|variant| !is_disabled(&extract_meta(&variant.attrs)))
        .enumerate()
        .map(|(idx, variant)| {
            let ident = &variant.ident;
            let params = match variant.fields {
                syn::Fields::Unit => quote! {},
                syn::Fields::Unnamed(ref fields) => {
                    let defaults =
                        ::std::iter::repeat(quote!(::std::default::Default::default()))
                            .take(fields.unnamed.len());
                    quote! { (#(#defaults),*) }
                }
                syn::Fields::Named(ref fields) => {
                    let fields = fields.named.iter().map(|field| {
                        field
                            .ident
                            .as_ref()
                            .expect("named fields always carry an identifier")
                    });
                    quote! { {#(#fields: ::std::default::Default::default()),*} }
                }
            };
            quote! { #idx => ::std::option::Option::Some(#name::#ident #params) }
        })
        .collect();
    let variant_count = arms.len();
    arms.push(quote! { _ => ::std::option::Option::None });

    let iter_name = syn::parse_str::<syn::Ident>(&format!("{}Iter", name)).unwrap();

    // The iterator struct plus the index -> variant lookup used by both ends.
    let iter_struct = quote! {
        #[allow(missing_docs)]
        #vis struct #iter_name #ty_generics {
            idx: usize,
            back_idx: usize,
            marker: ::std::marker::PhantomData #phantom_data,
        }

        impl #impl_generics #iter_name #ty_generics #where_clause {
            fn get(&self, idx: usize) -> ::std::option::Option<#name #ty_generics> {
                match idx {
                    #(#arms),*
                }
            }
        }
    };

    // Entry point: `Enum::iter()` starts with both cursors at zero.
    let into_iter_impl = quote! {
        impl #impl_generics ::strum::IntoEnumIterator for #name #ty_generics #where_clause {
            type Iterator = #iter_name #ty_generics;
            fn iter() -> #iter_name #ty_generics {
                #iter_name {
                    idx: 0,
                    back_idx: 0,
                    marker: ::std::marker::PhantomData,
                }
            }
        }
    };

    // Front and back iteration share one notion of "remaining": `size_hint`. A cursor only
    // advances when something remains, so `idx + back_idx <= variant_count` always holds.
    let iterator_impls = quote! {
        impl #impl_generics ::std::iter::Iterator for #iter_name #ty_generics #where_clause {
            type Item = #name #ty_generics;

            fn next(&mut self) -> ::std::option::Option<#name #ty_generics> {
                if <Self as ::std::iter::Iterator>::size_hint(self).0 == 0 {
                    return ::std::option::Option::None;
                }
                let output = self.get(self.idx);
                self.idx += 1;
                output
            }

            fn size_hint(&self) -> (usize, ::std::option::Option<usize>) {
                // Saturating, so the count can never underflow.
                let remaining = #variant_count.saturating_sub(self.idx + self.back_idx);
                (remaining, ::std::option::Option::Some(remaining))
            }
        }

        impl #impl_generics ::std::iter::ExactSizeIterator for #iter_name #ty_generics #where_clause {
            fn len(&self) -> usize {
                <Self as ::std::iter::Iterator>::size_hint(self).0
            }
        }

        impl #impl_generics ::std::iter::DoubleEndedIterator for #iter_name #ty_generics #where_clause {
            fn next_back(&mut self) -> ::std::option::Option<#name #ty_generics> {
                if <Self as ::std::iter::Iterator>::size_hint(self).0 == 0 {
                    return ::std::option::Option::None;
                }
                // Something remains, so `idx + back_idx < variant_count` holds and the
                // subtraction below cannot underflow.
                self.back_idx += 1;
                self.get(#variant_count - self.back_idx)
            }
        }
    };

    let clone_impl = quote! {
        impl #impl_generics ::std::clone::Clone for #iter_name #ty_generics #where_clause {
            fn clone(&self) -> #iter_name #ty_generics {
                #iter_name {
                    idx: self.idx,
                    back_idx: self.back_idx,
                    marker: ::std::marker::PhantomData,
                }
            }
        }
    };

    quote! {
        #iter_struct
        #into_iter_impl
        #iterator_impls
        #clone_impl
    }
}