diff --git a/src/ast/helpers.rs b/src/ast/helpers.rs index 0c16f4bccadc4..4c8ef664bdebf 100644 --- a/src/ast/helpers.rs +++ b/src/ast/helpers.rs @@ -213,23 +213,6 @@ pub fn is_constant_non_singleton(expr: &Expr) -> bool { is_constant(expr) && !is_singleton(expr) } -/// Return `true` if an `Expr` is not a reference to a variable (or something -/// that could resolve to a variable, like a function call). -pub fn is_non_variable(expr: &Expr) -> bool { - matches!( - expr.node, - ExprKind::Constant { .. } - | ExprKind::Tuple { .. } - | ExprKind::List { .. } - | ExprKind::Set { .. } - | ExprKind::Dict { .. } - | ExprKind::SetComp { .. } - | ExprKind::ListComp { .. } - | ExprKind::DictComp { .. } - | ExprKind::GeneratorExp { .. } - ) -} - /// Return the `Keyword` with the given name, if it's present in the list of /// `Keyword` arguments. pub fn find_keyword<'a>(keywords: &'a [Keyword], keyword_name: &str) -> Option<&'a Keyword> { diff --git a/src/checkers/ast.rs b/src/checkers/ast.rs index 86f88b9d4e2f5..59bf2bd4d8112 100644 --- a/src/checkers/ast.rs +++ b/src/checkers/ast.rs @@ -191,13 +191,18 @@ impl<'a> Checker<'a> { && match_call_path(call_path, "typing_extensions", target, &self.from_imports)) } - /// Return `true` if `member` is bound as a builtin. - pub fn is_builtin(&self, member: &str) -> bool { + /// Return the current `Binding` for a given `name`. + pub fn find_binding(&self, member: &str) -> Option<&Binding> { self.current_scopes() .find_map(|scope| scope.values.get(member)) - .map_or(false, |index| { - matches!(self.bindings[*index].kind, BindingKind::Builtin) - }) + .map(|index| &self.bindings[*index]) + } + + /// Return `true` if `member` is bound as a builtin. + pub fn is_builtin(&self, member: &str) -> bool { + self.find_binding(member).map_or(false, |binding| { + matches!(binding.kind, BindingKind::Builtin) + }) } /// Return `true` if a `CheckCode` is disabled by a `noqa` directive. @@ -1727,10 +1732,30 @@ where } } // Avoid flagging on non-DataFrames (e.g., `{"a": 1}.values`). - if helpers::is_non_variable(value) { - continue; + if pandas_vet::helpers::is_dataframe_candidate(value) { + // If the target is a named variable, avoid triggering on + // irrelevant bindings (like imports). + if let ExprKind::Name { id, .. } = &value.node { + if self.find_binding(id).map_or(true, |binding| { + matches!( + binding.kind, + BindingKind::Builtin + | BindingKind::ClassDefinition + | BindingKind::FunctionDefinition + | BindingKind::Export(..) + | BindingKind::FutureImportation + | BindingKind::StarImportation(..) + | BindingKind::Importation(..) + | BindingKind::FromImportation(..) + | BindingKind::SubmoduleImportation(..) + ) + }) { + continue; + } + } + + self.add_check(Check::new(code.kind(), Range::from_located(expr))); } - self.add_check(Check::new(code.kind(), Range::from_located(expr))); }; } } @@ -2158,9 +2183,41 @@ where (CheckCode::PD013, "stack"), ] { if self.settings.enabled.contains(&code) { - if let ExprKind::Attribute { attr, .. } = &func.node { + if let ExprKind::Attribute { value, attr, .. } = &func.node { if attr == name { - self.add_check(Check::new(code.kind(), Range::from_located(func))); + if pandas_vet::helpers::is_dataframe_candidate(value) { + // If the target is a named variable, avoid triggering on + // irrelevant bindings (like non-Pandas imports). + if let ExprKind::Name { id, .. } = &value.node { + if self.find_binding(id).map_or(true, |binding| { + if let BindingKind::Importation(.., module) = + &binding.kind + { + module != "pandas" + } else { + matches!( + binding.kind, + BindingKind::Builtin + | BindingKind::ClassDefinition + | BindingKind::FunctionDefinition + | BindingKind::Export(..) + | BindingKind::FutureImportation + | BindingKind::StarImportation(..) + | BindingKind::Importation(..) + | BindingKind::FromImportation(..) + | BindingKind::SubmoduleImportation(..) + ) + } + }) { + continue; + } + } + + self.add_check(Check::new( + code.kind(), + Range::from_located(func), + )); + } }; } } diff --git a/src/pandas_vet/helpers.rs b/src/pandas_vet/helpers.rs new file mode 100644 index 0000000000000..ea3d6ee630649 --- /dev/null +++ b/src/pandas_vet/helpers.rs @@ -0,0 +1,18 @@ +use rustpython_ast::{Expr, ExprKind}; + +/// Return `true` if an `Expr` _could_ be a `DataFrame`. This rules out +/// obviously-wrong cases, like constants and literals. +pub fn is_dataframe_candidate(expr: &Expr) -> bool { + !matches!( + expr.node, + ExprKind::Constant { .. } + | ExprKind::Tuple { .. } + | ExprKind::List { .. } + | ExprKind::Set { .. } + | ExprKind::Dict { .. } + | ExprKind::SetComp { .. } + | ExprKind::ListComp { .. } + | ExprKind::DictComp { .. } + | ExprKind::GeneratorExp { .. } + ) +} diff --git a/src/pandas_vet/mod.rs b/src/pandas_vet/mod.rs index 3ca6d8affc2de..bdfec25fb777e 100644 --- a/src/pandas_vet/mod.rs +++ b/src/pandas_vet/mod.rs @@ -1,4 +1,5 @@ pub mod checks; +pub mod helpers; #[cfg(test)] mod tests { @@ -47,64 +48,201 @@ mod tests { Ok(()) } - #[test_case("df.drop(['a'], axis=1, inplace=False)", &[]; "PD002_pass")] - #[test_case("df.drop(['a'], axis=1, inplace=True)", &[CheckCode::PD002]; "PD002_fail")] - #[test_case("nas = pd.isna(val)", &[]; "PD003_pass")] - #[test_case("nulls = pd.isnull(val)", &[CheckCode::PD003]; "PD003_fail")] - #[test_case("print('bah humbug')", &[]; "PD003_allows_other_calls")] - #[test_case("not_nas = pd.notna(val)", &[]; "PD004_pass")] - #[test_case("not_nulls = pd.notnull(val)", &[CheckCode::PD004]; "PD004_fail")] - #[test_case("new_df = df.loc['d':, 'A':'C']", &[]; "PD007_pass_loc")] - #[test_case("new_df = df.iloc[[1, 3, 5], [1, 3]]", &[]; "PD007_pass_iloc")] - #[test_case("s = df.ix[[0, 2], 'A']", &[CheckCode::PD007]; "PD007_fail")] - #[test_case("index = df.loc[:, ['B', 'A']]", &[]; "PD008_pass")] - #[test_case("index = df.at[:, ['B', 'A']]", &[CheckCode::PD008]; "PD008_fail")] - #[test_case("index = df.iloc[:, 1:3]", &[]; "PD009_pass")] - #[test_case("index = df.iat[:, 1:3]", &[CheckCode::PD009]; "PD009_fail")] - #[test_case(r#"table = df.pivot_table( - df, - values='D', - index=['A', 'B'], - columns=['C'], - aggfunc=np.sum, - fill_value=0 - ) + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + x.drop(['a'], axis=1, inplace=False) + "#, &[]; "PD002_pass")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + x.drop(['a'], axis=1, inplace=True) + "#, &[CheckCode::PD002]; "PD002_fail")] + #[test_case(r#" + import pandas as pd + nas = pd.isna(val) + "#, &[]; "PD003_pass")] + #[test_case(r#" + import pandas as pd + nulls = pd.isnull(val) + "#, &[CheckCode::PD003]; "PD003_fail")] + #[test_case(r#" + import pandas as pd + print('bah humbug') + "#, &[]; "PD003_allows_other_calls")] + #[test_case(r#" + import pandas as pd + not_nas = pd.notna(val) + "#, &[]; "PD004_pass")] + #[test_case(r#" + import pandas as pd + not_nulls = pd.notnull(val) + "#, &[CheckCode::PD004]; "PD004_fail")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + new_x = x.loc['d':, 'A':'C'] + "#, &[]; "PD007_pass_loc")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + new_x = x.iloc[[1, 3, 5], [1, 3]] + "#, &[]; "PD007_pass_iloc")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + y = x.ix[[0, 2], 'A'] + "#, &[CheckCode::PD007]; "PD007_fail")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + index = x.loc[:, ['B', 'A']] + "#, &[]; "PD008_pass")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + index = x.at[:, ['B', 'A']] + "#, &[CheckCode::PD008]; "PD008_fail")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + index = x.iloc[:, 1:3] + "#, &[]; "PD009_pass")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + index = x.iat[:, 1:3] + "#, &[CheckCode::PD009]; "PD009_fail")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + table = x.pivot_table( + x, + values='D', + index=['A', 'B'], + columns=['C'], + aggfunc=np.sum, + fill_value=0 + ) "#, &[]; "PD010_pass")] - #[test_case(r#"table = pd.pivot( - df, - index='foo', - columns='bar', - values='baz' - ) + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + table = pd.pivot( + x, + index='foo', + columns='bar', + values='baz' + ) "#, &[CheckCode::PD010]; "PD010_fail_pivot")] - #[test_case("result = df.to_array()", &[]; "PD011_pass_to_array")] - #[test_case("result = df.array", &[]; "PD011_pass_array")] - #[test_case("result = df.values", &[CheckCode::PD011]; "PD011_fail_values")] - #[test_case("result = df.values()", &[]; "PD011_pass_values_call")] - #[test_case("result = {}.values", &[]; "PD011_pass_values_dict")] - #[test_case("result = values", &[]; "PD011_pass_node_name")] - #[test_case("employees = pd.read_csv(input_file)", &[]; "PD012_pass_read_csv")] - #[test_case("employees = pd.read_table(input_file)", &[CheckCode::PD012]; "PD012_fail_read_table")] - #[test_case("employees = read_table", &[]; "PD012_node_Name_pass")] - #[test_case(r#"table = df.melt( - id_vars='airline', - value_vars=['ATL', 'DEN', 'DFW'], - value_name='airline delay' + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + result = x.to_array() + "#, &[]; "PD011_pass_to_array")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + result = x.array + "#, &[]; "PD011_pass_array")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + result = x.values + "#, &[CheckCode::PD011]; "PD011_fail_values")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + result = x.values() + "#, &[]; "PD011_pass_values_call")] + #[test_case(r#" + import pandas as pd + result = {}.values + "#, &[]; "PD011_pass_values_dict")] + #[test_case(r#" + import pandas as pd + result = pd.values + "#, &[]; "PD011_pass_values_import")] + #[test_case(r#" + import pandas as pd + result = x.values + "#, &[]; "PD011_pass_values_unbound")] + #[test_case(r#" + import pandas as pd + result = values + "#, &[]; "PD011_pass_node_name")] + #[test_case(r#" + import pandas as pd + employees = pd.read_csv(input_file) + "#, &[]; "PD012_pass_read_csv")] + #[test_case(r#" + import pandas as pd + employees = pd.read_table(input_file) + "#, &[CheckCode::PD012]; "PD012_fail_read_table")] + #[test_case(r#" + import pandas as pd + employees = read_table + "#, &[]; "PD012_node_Name_pass")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + y = x.melt( + id_vars='airline', + value_vars=['ATL', 'DEN', 'DFW'], + value_name='airline delay' ) "#, &[]; "PD013_pass")] - #[test_case("table = df.stack(level=-1, dropna=True)", &[CheckCode::PD013]; "PD013_fail_stack")] - #[test_case("df1.merge(df2)", &[]; "PD015_pass_merge_on_dataframe")] - #[test_case("df1.merge(df2, 'inner')", &[]; "PD015_pass_merge_on_dataframe_with_multiple_args")] - #[test_case("pd.merge(df1, df2)", &[CheckCode::PD015]; "PD015_fail_merge_on_pandas_object")] + #[test_case(r#" + import numpy as np + arrays = [np.random.randn(3, 4) for _ in range(10)] + np.stack(arrays, axis=0).shape + "#, &[]; "PD013_pass_numpy")] + #[test_case(r#" + import pandas as pd + y = x.stack(level=-1, dropna=True) + "#, &[]; "PD013_pass_unbound")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + y = x.stack(level=-1, dropna=True) + "#, &[CheckCode::PD013]; "PD013_fail_stack")] + #[test_case(r#" + import pandas as pd + pd.stack( + "#, &[]; "PD015_pass_merge_on_dataframe")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + y = pd.DataFrame() + x.merge(y, 'inner') + "#, &[]; "PD015_pass_merge_on_dataframe_with_multiple_args")] + #[test_case(r#" + import pandas as pd + x = pd.DataFrame() + y = pd.DataFrame() + pd.merge(x, y) + "#, &[CheckCode::PD015]; "PD015_fail_merge_on_pandas_object")] #[test_case( "pd.to_datetime(timestamp * 10 ** 9).strftime('%Y-%m-%d %H:%M:%S.%f')", &[]; "PD015_pass_other_pd_function" )] - #[test_case("employees = pd.DataFrame(employee_dict)", &[]; "PD901_pass_non_df")] - #[test_case("employees_df = pd.DataFrame(employee_dict)", &[]; "PD901_pass_part_df")] - #[test_case("my_function(df=data)", &[]; "PD901_pass_df_param")] - #[test_case("df = pd.DataFrame()", &[CheckCode::PD901]; "PD901_fail_df_var")] + #[test_case(r#" + import pandas as pd + employees = pd.DataFrame(employee_dict) + "#, &[]; "PD901_pass_non_df")] + #[test_case(r#" + import pandas as pd + employees_df = pd.DataFrame(employee_dict) + "#, &[]; "PD901_pass_part_df")] + #[test_case(r#" + import pandas as pd + my_function(df=data) + "#, &[]; "PD901_pass_df_param")] + #[test_case(r#" + import pandas as pd + df = pd.DataFrame() + "#, &[CheckCode::PD901]; "PD901_fail_df_var")] fn test_pandas_vet(code: &str, expected: &[CheckCode]) -> Result<()> { check_code(code, expected)?; Ok(())