diff --git a/resources/test/fixtures/pyupgrade/UP028_1.py b/resources/test/fixtures/pyupgrade/UP028_1.py index 1ef3e76858f2e..db49d6866882a 100644 --- a/resources/test/fixtures/pyupgrade/UP028_1.py +++ b/resources/test/fixtures/pyupgrade/UP028_1.py @@ -111,3 +111,13 @@ def f(): class C: def __init__(self): print(x) + + +def f(): + for x in y: + yield x, x + 1 + + +def f(): + for x, y in z: + yield x, y, x + y diff --git a/src/pyupgrade/plugins/rewrite_yield_from.rs b/src/pyupgrade/plugins/rewrite_yield_from.rs index ee8a5f1aaee0f..b19150d536dc8 100644 --- a/src/pyupgrade/plugins/rewrite_yield_from.rs +++ b/src/pyupgrade/plugins/rewrite_yield_from.rs @@ -8,11 +8,25 @@ use crate::autofix::Fix; use crate::checkers::ast::Checker; use crate::registry::{Check, CheckKind}; +/// Return `true` if the two expressions are equivalent, and consistent solely +/// of tuples and names. +fn is_same_expr(a: &Expr, b: &Expr) -> bool { + match (&a.node, &b.node) { + (ExprKind::Name { id: a, .. }, ExprKind::Name { id: b, .. }) => a == b, + (ExprKind::Tuple { elts: a, .. }, ExprKind::Tuple { elts: b, .. }) => { + a.len() == b.len() && a.iter().zip(b).all(|(a, b)| is_same_expr(a, b)) + } + _ => false, + } +} + +/// Collect all named variables in an expression consisting solely of tuples and +/// names. fn collect_names(expr: &Expr) -> Vec<&str> { match &expr.node { ExprKind::Name { id, .. } => vec![id], ExprKind::Tuple { elts, .. } => elts.iter().flat_map(collect_names).collect(), - _ => vec![], + _ => unreachable!("Expected: ExprKind::Name | ExprKind::Tuple"), } } @@ -51,13 +65,12 @@ impl<'a> Visitor<'a> for YieldFromVisitor<'a> { let body = &body[0]; if let StmtKind::Expr { value } = &body.node { if let ExprKind::Yield { value: Some(value) } = &value.node { - let names = collect_names(target); - if names == collect_names(value) { + if is_same_expr(target, value) { self.yields.push(YieldFrom { stmt, body, iter, - names, + names: collect_names(target), }); } }