Skip to content

Commit

Permalink
Add missed shortcircuit lint for collect calls
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusklaas committed Oct 26, 2017
1 parent 0b0fe69 commit d18648e
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 0 deletions.
194 changes: 194 additions & 0 deletions clippy_lints/src/collect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
use itertools::{repeat_n, Itertools};
use rustc::hir::*;
use rustc::lint::*;
use rustc::ty::TypeVariants;
use syntax::ast::NodeId;

use std::collections::HashSet;

use utils::{match_trait_method, match_type, span_lint_and_sugg};
use utils::paths;

/// **What it does:** Detects collect calls on iterators to collections
/// of either `Result<_, E>` or `Option<_>` inside functions that also
/// have such a return type.
///
/// **Why is this bad?** It is possible to short-circuit these collect
/// calls and return early whenever a `None` or `Err(E)` is encountered.
///
/// **Known problems:** It may be possible that a collection of options
/// or results is intended. This would then generate a false positive.
///
/// **Example:**
/// ```rust
/// pub fn div(a: i32, b: &[i32]) -> Result<Vec<i32>, String> {
/// let option_vec: Vec<_> = b.into_iter()
/// .cloned()
/// .map(|i| if i != 0 {
/// Ok(a / i)
/// } else {
/// Err("Division by zero!".to_owned())
/// })
/// .collect();
/// let mut int_vec = Vec::new();
/// for opt in option_vec {
/// int_vec.push(opt?);
/// }
/// Ok(int_vec)
/// }
/// ```
declare_lint! {
pub POSSIBLE_SHORTCIRCUITING_COLLECT,
Allow,
"missed shortcircuit opportunity on collect"
}

#[derive(Clone)]
pub struct Pass {
// To ensure that we do not lint the same expression more than once
seen_expr_nodes: HashSet<NodeId>,
}

impl Pass {
pub fn new() -> Self {
Self { seen_expr_nodes: HashSet::new() }
}
}

impl LintPass for Pass {
fn get_lints(&self) -> LintArray {
lint_array!(POSSIBLE_SHORTCIRCUITING_COLLECT)
}
}

struct Suggestion {
pattern: String,
type_colloquial: &'static str,
success_variant: &'static str,
}

fn format_suggestion_pattern<'a, 'tcx>(
cx: &LateContext<'a, 'tcx>,
collection_ty: TypeVariants,
is_option: bool,
) -> String {
let collection_pat = match collection_ty {
TypeVariants::TyAdt(def, subs) => {
let mut buf = cx.tcx.item_path_str(def.did);

if !subs.is_empty() {
buf.push('<');
buf.push_str(&repeat_n('_', subs.len()).join(", "));
buf.push('>');
}

buf
},
TypeVariants::TyParam(p) => p.to_string(),
_ => "_".into(),
};

if is_option {
format!("Option<{}>", collection_pat)
} else {
format!("Result<{}, _>", collection_pat)
}
}

fn check_expr_for_collect<'a, 'tcx>(cx: &LateContext<'a, 'tcx>, expr: &'tcx Expr) -> Option<Suggestion> {
if let ExprMethodCall(ref method, _, ref args) = expr.node {
if args.len() == 1 && method.name == "collect" && match_trait_method(cx, expr, &paths::ITERATOR) {
let collect_ty = cx.tables.expr_ty(expr);

if match_type(cx, collect_ty, &paths::OPTION) || match_type(cx, collect_ty, &paths::RESULT) {
// Already collecting into an Option or Result - good!
return None;
}

// Get the type of the Item associated to the Iterator on which collect() is
// called.
let arg_ty = cx.tables.expr_ty(&args[0]);
let method_call = cx.tables.type_dependent_defs()[args[0].hir_id];
let trt_id = cx.tcx.trait_of_item(method_call.def_id()).unwrap();
let assoc_item_id = cx.tcx.associated_items(trt_id).next().unwrap().def_id;
let substitutions = cx.tcx.mk_substs_trait(arg_ty, &[]);
let projection = cx.tcx.mk_projection(assoc_item_id, substitutions);
let normal_ty = cx.tcx.normalize_associated_type_in_env(
&projection,
cx.param_env,
);

return if match_type(cx, normal_ty, &paths::OPTION) {
Some(Suggestion {
pattern: format_suggestion_pattern(cx, collect_ty.sty.clone(), true),
type_colloquial: "Option",
success_variant: "Some",
})
} else if match_type(cx, normal_ty, &paths::RESULT) {
Some(Suggestion {
pattern: format_suggestion_pattern(cx, collect_ty.sty.clone(), false),
type_colloquial: "Result",
success_variant: "Ok",
})
} else {
None
};
}
}

None
}

impl<'a, 'tcx> LateLintPass<'a, 'tcx> for Pass {
fn check_expr(&mut self, cx: &LateContext<'a, 'tcx>, expr: &'tcx Expr) {
if self.seen_expr_nodes.contains(&expr.id) {
return;
}

if let Some(suggestion) = check_expr_for_collect(cx, expr) {
let sugg_span = if let ExprMethodCall(_, call_span, _) = expr.node {
expr.span.between(call_span)
} else {
unreachable!()
};

span_lint_and_sugg(
cx,
POSSIBLE_SHORTCIRCUITING_COLLECT,
sugg_span,
&format!("you are creating a collection of `{}`s", suggestion.type_colloquial),
&format!(
"if you are only interested in the case where all values are `{}`, try",
suggestion.success_variant
),
format!("collect::<{}>()", suggestion.pattern),
);
}
}

fn check_stmt(&mut self, cx: &LateContext<'a, 'tcx>, stmt: &'tcx Stmt) {
if_chain! {
if let StmtDecl(ref decl, _) = stmt.node;
if let DeclLocal(ref local) = decl.node;
if let Some(ref ty) = local.ty;
if let Some(ref expr) = local.init;
then {
self.seen_expr_nodes.insert(expr.id);

if let Some(suggestion) = check_expr_for_collect(cx, expr) {
span_lint_and_sugg(
cx,
POSSIBLE_SHORTCIRCUITING_COLLECT,
ty.span,
&format!("you are creating a collection of `{}`s", suggestion.type_colloquial),
&format!(
"if you are only interested in the case where all values are `{}`, try",
suggestion.success_variant
),
suggestion.pattern
);
}
}
}
}
}
3 changes: 3 additions & 0 deletions clippy_lints/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub mod booleans;
pub mod bytecount;
pub mod collapsible_if;
pub mod const_static_lifetime;
pub mod collect;
pub mod copies;
pub mod cyclomatic_complexity;
pub mod derive;
Expand Down Expand Up @@ -350,6 +351,7 @@ pub fn register_plugins(reg: &mut rustc_plugin::Registry) {
reg.register_late_lint_pass(box types::ImplicitHasher);
reg.register_early_lint_pass(box const_static_lifetime::StaticConst);
reg.register_late_lint_pass(box fallible_impl_from::FallibleImplFrom);
reg.register_late_lint_pass(box collect::Pass::new());

reg.register_lint_group("clippy_restrictions", vec![
arithmetic::FLOAT_ARITHMETIC,
Expand Down Expand Up @@ -423,6 +425,7 @@ pub fn register_plugins(reg: &mut rustc_plugin::Registry) {
booleans::LOGIC_BUG,
bytecount::NAIVE_BYTECOUNT,
collapsible_if::COLLAPSIBLE_IF,
collect::POSSIBLE_SHORTCIRCUITING_COLLECT,
copies::IF_SAME_THEN_ELSE,
copies::IFS_SAME_COND,
copies::MATCH_SAME_ARMS,
Expand Down
38 changes: 38 additions & 0 deletions tests/ui/collect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#![feature(plugin, inclusive_range_syntax)]
#![plugin(clippy)]

use std::iter::FromIterator;

#[warn(possible_shortcircuiting_collect)]
pub fn div(a: i32, b: &[i32]) -> Result<Vec<i32>, String> {
let option_vec: Vec<_> = b.into_iter()
.cloned()
.map(|i| if i != 0 {
Ok(a / i)
} else {
Err("Division by zero!".to_owned())
})
.collect();
let mut int_vec = Vec::new();
for opt in option_vec {
int_vec.push(opt?);
}
Ok(int_vec)
}

#[warn(possible_shortcircuiting_collect)]
pub fn generic<T>(a: &[T]) {
// Make sure that our lint also works for generic functions.
let _result: Vec<_> = a.iter().map(Some).collect();
}

#[warn(possible_shortcircuiting_collect)]
pub fn generic_collection<T, C: FromIterator<T> + FromIterator<Option<T>>>(elem: T) -> C {
Some(Some(elem)).into_iter().collect()
}

#[warn(possible_shortcircuiting_collect)]
fn main() {
// We're collecting into an `Option`. Do not trigger lint.
let _sup: Option<Vec<_>> = (0..5).map(Some).collect();
}
36 changes: 36 additions & 0 deletions tests/ui/collect.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
warning: running cargo clippy on a crate that also imports the clippy plugin

error: you are creating a collection of `Result`s
--> $DIR/collect.rs:8:21
|
8 | let option_vec: Vec<_> = b.into_iter()
| ^^^^^^
|
= note: `-D possible-shortcircuiting-collect` implied by `-D warnings`
help: if you are only interested in the case where all values are `Ok`, try
|
8 | let option_vec: Result<std::vec::Vec<_>, _> = b.into_iter()
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: you are creating a collection of `Option`s
--> $DIR/collect.rs:26:18
|
26 | let _result: Vec<_> = a.iter().map(Some).collect();
| ^^^^^^
|
help: if you are only interested in the case where all values are `Some`, try
|
26 | let _result: Option<std::vec::Vec<_>> = a.iter().map(Some).collect();
| ^^^^^^^^^^^^^^^^^^^^^^^^

error: you are creating a collection of `Option`s
--> $DIR/collect.rs:31:34
|
31 | Some(Some(elem)).into_iter().collect()
| ^^^^^^^^^
|
help: if you are only interested in the case where all values are `Some`, try
|
31 | Some(Some(elem)).into_iter().collect::<Option<C>>()
| ^^^^^^^^^^^^^^^^^^^^^^

0 comments on commit d18648e

Please sign in to comment.