Skip to content

Commit

Permalink
Add support for methods taking literal constant args in Access Paths. (
Browse files Browse the repository at this point in the history
…#285)

Consider the code below:

```
if (x.get(0) != null && x.get(0).foo() != null) {
   return x.get(0).foo().bar(); 
}
```

This code is safe, but NullAway would miss that before this change, due to `get(0)` taking an argument. We handle `foo()` (a zero-arguments method) just fine. 

This patch extends our support for method calls with only literal values (and boxed literal values) being passed as arguments. I don't see a case where this would add unsoundness that isn't present on the zero-args case.
  • Loading branch information
lazaroclapp committed Mar 14, 2019
1 parent 4a4b591 commit 0b9cdef
Show file tree
Hide file tree
Showing 6 changed files with 220 additions and 25 deletions.
77 changes: 59 additions & 18 deletions nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java
Expand Up @@ -25,6 +25,9 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.LiteralTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.Tree;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Types;
Expand All @@ -45,29 +48,30 @@
* Represents an extended notion of an access path, which we track for nullness.
*
* <p>Typically, access paths are of the form x.f.g.h, where x is a variable and f, g, and h are
* field names. Here, we also allow no-argument methods to appear in the access path, so it can be
* of the form x.f().g.h() in general.
* field names. Here, we also allow no-argument methods to appear in the access path, as well as
* method calls passed only statically constant parameters, so an AP can be of the form
* x.f().g.h([int_expr|string_expr]) in general.
*
* <p>We do not allow array accesses in access paths for the moment.
*/
public final class AccessPath {

private final Root root;

private final ImmutableList<Element> elements;
private final ImmutableList<AccessPathElement> elements;

/**
* if present, the argument to the map get() method call that is the final element of this path
*/
@Nullable private final AccessPath mapGetArgAccessPath;

AccessPath(Root root, List<Element> elements) {
AccessPath(Root root, List<AccessPathElement> elements) {
this.root = root;
this.elements = ImmutableList.copyOf(elements);
this.mapGetArgAccessPath = null;
}

private AccessPath(Root root, List<Element> elements, AccessPath mapGetArgAccessPath) {
private AccessPath(Root root, List<AccessPathElement> elements, AccessPath mapGetArgAccessPath) {
this.root = root;
this.elements = ImmutableList.copyOf(elements);
this.mapGetArgAccessPath = mapGetArgAccessPath;
Expand Down Expand Up @@ -96,7 +100,7 @@ static AccessPath fromVarDecl(VariableDeclarationNode node) {
*/
@Nullable
static AccessPath fromFieldAccess(FieldAccessNode node) {
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(node, elements);
return (root != null) ? new AccessPath(root, elements) : null;
}
Expand All @@ -115,7 +119,7 @@ static AccessPath fromMethodCall(MethodInvocationNode node, @Nullable Types type

@Nullable
private static AccessPath fromVanillaMethodCall(MethodInvocationNode node) {
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(node, elements);
return (root != null) ? new AccessPath(root, elements) : null;
}
Expand All @@ -127,12 +131,12 @@ private static AccessPath fromVanillaMethodCall(MethodInvocationNode node) {
*/
@Nullable
public static AccessPath fromBaseAndElement(Node base, Element element) {
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(base, elements);
if (root == null) {
return null;
}
elements.add(element);
elements.add(new AccessPathElement(element));
return new AccessPath(root, elements);
}

Expand Down Expand Up @@ -161,7 +165,7 @@ private static AccessPath fromMapGetCall(MethodInvocationNode node) {
}
MethodAccessNode target = node.getTarget();
Node receiver = target.getReceiver();
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(receiver, elements);
if (root == null) {
return null;
Expand Down Expand Up @@ -202,8 +206,14 @@ public static AccessPath getAccessPathForNodeWithMapGet(Node node, @Nullable Typ
}
}

private static boolean isBoxingMethod(Symbol.MethodSymbol methodSymbol) {
return methodSymbol.isStatic()
&& methodSymbol.getSimpleName().contentEquals("valueOf")
&& methodSymbol.enclClass().packge().fullname.contentEquals("java.lang");
}

@Nullable
private static Root populateElementsRec(Node node, List<Element> elements) {
private static Root populateElementsRec(Node node, List<AccessPathElement> elements) {
Root result;
if (node instanceof FieldAccessNode) {
FieldAccessNode fieldAccess = (FieldAccessNode) node;
Expand All @@ -213,17 +223,48 @@ private static Root populateElementsRec(Node node, List<Element> elements) {
} else {
// instance field access
result = populateElementsRec(fieldAccess.getReceiver(), elements);
elements.add(fieldAccess.getElement());
elements.add(new AccessPathElement(fieldAccess.getElement()));
}
} else if (node instanceof MethodInvocationNode) {
MethodInvocationNode invocation = (MethodInvocationNode) node;
// only support zero-argument methods
if (invocation.getArguments().size() > 0) {
return null;
}
AccessPathElement accessPathElement;
MethodAccessNode accessNode = invocation.getTarget();
if (invocation.getArguments().size() == 0) {
accessPathElement = new AccessPathElement(accessNode.getMethod());
} else {
List<String> constantArgumentValues = new ArrayList<>();
for (Node argumentNode : invocation.getArguments()) {
Tree tree = argumentNode.getTree();
if (tree == null) {
return null; // Not an AP
} else if (tree.getKind().equals(Tree.Kind.METHOD_INVOCATION)) {
// Check for boxing call
MethodInvocationTree methodInvocationTree = (MethodInvocationTree) tree;
if (methodInvocationTree.getArguments().size() == 1
&& isBoxingMethod(ASTHelpers.getSymbol(methodInvocationTree))) {
tree = methodInvocationTree.getArguments().get(0);
}
}
switch (tree.getKind()) {
case BOOLEAN_LITERAL:
case CHAR_LITERAL:
case DOUBLE_LITERAL:
case FLOAT_LITERAL:
case INT_LITERAL:
case LONG_LITERAL:
case STRING_LITERAL:
constantArgumentValues.add(((LiteralTree) tree).getValue().toString());
break;
case NULL_LITERAL:
// Um, probably not? Cascade to default for now.
default:
return null; // Not an AP
}
}
accessPathElement = new AccessPathElement(accessNode.getMethod(), constantArgumentValues);
}
result = populateElementsRec(accessNode.getReceiver(), elements);
elements.add(accessNode.getMethod());
elements.add(accessPathElement);
} else if (node instanceof LocalVariableNode) {
result = new Root(((LocalVariableNode) node).getElement());
} else if (node instanceof ThisLiteralNode) {
Expand Down Expand Up @@ -269,7 +310,7 @@ public Root getRoot() {
return root;
}

public ImmutableList<Element> getElements() {
public ImmutableList<AccessPathElement> getElements() {
return elements;
}

Expand Down
@@ -0,0 +1,67 @@
package com.uber.nullaway.dataflow;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nullable;
import javax.lang.model.element.Element;

/**
* Represents a (non-root) element of an AccessPath.
*
* <p>This is just a java Element (field, method, etc) in the access-path chain (e.g. f or g() in
* x.f.g()). Plus, optionally, a list of constant arguments, allowing access path elements for
* method calls with constant values (e.g. h(3) or k("STR_KEY") in x.h(3).g().k("STR_KEY")).
*/
public final class AccessPathElement {
private final Element javaElement;
@Nullable private final ImmutableList<String> constantArguments;

public AccessPathElement(Element javaElement, List<String> constantArguments) {
this.javaElement = javaElement;
this.constantArguments = ImmutableList.copyOf(constantArguments);
}

public AccessPathElement(Element javaElement) {
this.javaElement = javaElement;
this.constantArguments = null;
}

public Element getJavaElement() {
return this.javaElement;
}

public ImmutableList<String> getConstantArguments() {
return this.constantArguments;
}

@Override
public boolean equals(Object obj) {
if (obj instanceof AccessPathElement) {
AccessPathElement otherNode = (AccessPathElement) obj;
return this.javaElement.equals(otherNode.javaElement)
&& (constantArguments == null
? otherNode.constantArguments == null
: constantArguments.equals(otherNode.constantArguments));
} else {
return false;
}
}

@Override
public int hashCode() {
int result = javaElement.hashCode();
result = 31 * result + (constantArguments != null ? constantArguments.hashCode() : 0);
return result;
}

@Override
public String toString() {
return "APElement{"
+ "javaElement="
+ javaElement.toString()
+ ", constantArguments="
+ Arrays.deepToString(constantArguments.toArray())
+ '}';
}
}
Expand Up @@ -112,9 +112,9 @@ private Set<Element> getNonnullReceiverFields(NullnessStore nullnessResult) {
Set<Element> result = new LinkedHashSet<>();
for (AccessPath ap : nonnullAccessPaths) {
if (ap.getRoot().isReceiver()) {
ImmutableList<Element> elements = ap.getElements();
ImmutableList<AccessPathElement> elements = ap.getElements();
if (elements.size() == 1) {
Element elem = elements.get(0);
Element elem = elements.get(0).getJavaElement();
if (elem.getKind().equals(ElementKind.FIELD)) {
result.add(elem);
}
Expand Down
Expand Up @@ -120,9 +120,11 @@ public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState s

if (accessPath.getElements().size() == 1) {
AccessPath.Root root = accessPath.getRoot();
if (!root.isReceiver() && (accessPath.getElements().get(0) instanceof Symbol.MethodSymbol)) {
if (!root.isReceiver()
&& (accessPath.getElements().get(0).getJavaElement() instanceof Symbol.MethodSymbol)) {
final Element e = root.getVarElement();
final Symbol.MethodSymbol g = (Symbol.MethodSymbol) accessPath.getElements().get(0);
final Symbol.MethodSymbol g =
(Symbol.MethodSymbol) accessPath.getElements().get(0).getJavaElement();
return e.getKind().equals(ElementKind.LOCAL_VARIABLE)
&& optionalIsGetCall(g, state.getTypes());
}
Expand Down
Expand Up @@ -49,6 +49,7 @@
import com.uber.nullaway.NullabilityUtil;
import com.uber.nullaway.Nullness;
import com.uber.nullaway.dataflow.AccessPath;
import com.uber.nullaway.dataflow.AccessPathElement;
import com.uber.nullaway.dataflow.AccessPathNullnessAnalysis;
import com.uber.nullaway.dataflow.NullnessStore;
import java.util.ArrayList;
Expand Down Expand Up @@ -393,10 +394,10 @@ public void onMatchMethodReference(
assert filterNullnessStore != null;
for (AccessPath ap : filterNullnessStore.getAccessPathsWithValue(Nullness.NONNULL)) {
// Find the access path corresponding to the current unbound method reference after binding
ImmutableList<Element> elements = ap.getElements();
ImmutableList<AccessPathElement> elements = ap.getElements();
if (elements.size() == 1) {
// We only care for single method call chains (e.g. this.foo(), not this.f.bar())
Element element = elements.get(0);
Element element = elements.get(0).getJavaElement();
if (!element.getKind().equals(ElementKind.METHOD)) {
// We are only looking for method APs
continue;
Expand Down
86 changes: 85 additions & 1 deletion nullaway/src/test/java/com/uber/nullaway/NullAwayTest.java
Expand Up @@ -584,7 +584,6 @@ public void testThriftIsSet() {
" }",
" java.util.List<Generated> l = new java.util.ArrayList<>();",
" if (l.get(0).isSetId()) {",
" // BUG: Diagnostic contains: dereferenced expression l.get(0).getId()",
" l.get(0).getId().hashCode();",
" }",
" }",
Expand Down Expand Up @@ -1431,4 +1430,89 @@ public void testEnhancedFor() {
"}")
.doTest();
}

@Test
public void testConstantsInAccessPathsNegative() {
compilationHelper
.addSourceLines(
"NullableContainer.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public interface NullableContainer<K, V> {",
" @Nullable public V get(K k);",
"}")
.addSourceLines(
"Test.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public class Test {",
" public void testSingleStringCheck(NullableContainer<String, Object> c) {",
" if (c.get(\"KEY_STR\") != null) {",
" c.get(\"KEY_STR\").toString(); // is safe",
" }",
" }",
" public void testSingleIntCheck(NullableContainer<Integer, Object> c) {",
" if (c.get(42) != null) {",
" c.get(42).toString(); // is safe",
" }",
" }",
" public void testMultipleChecks(NullableContainer<String, NullableContainer<Integer, Object>> c) {",
" if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(42) != null) {",
" c.get(\"KEY_STR\").get(42).toString(); // is safe",
" }",
" }",
"}")
.doTest();
}

@Test
public void testConstantsInAccessPathsPositive() {
compilationHelper
.addSourceLines(
"NullableContainer.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public interface NullableContainer<K, V> {",
" @Nullable public V get(K k);",
"}")
.addSourceLines(
"Test.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public class Test {",
" public void testEnhancedFor(NullableContainer<String, NullableContainer<Integer, Object>> c) {",
" if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(0) != null) {",
" // BUG: Diagnostic contains: dereferenced expression c.get(\"KEY_STR\").get(42)",
" c.get(\"KEY_STR\").get(42).toString();",
" }",
" }",
"}")
.doTest();
}

@Test
public void testVariablesInAccessPathsPositive() {
compilationHelper
.addSourceLines(
"NullableContainer.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public interface NullableContainer<K, V> {",
" @Nullable public V get(K k);",
"}")
.addSourceLines(
"Test.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public class Test {",
" private Integer intKey = 42;", // No guarantee it's a constant
" public void testEnhancedFor(NullableContainer<String, NullableContainer<Integer, Object>> c) {",
" if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(this.intKey) != null) {",
" // BUG: Diagnostic contains: dereferenced expression c.get(\"KEY_STR\").get",
" c.get(\"KEY_STR\").get(this.intKey).toString();",
" }",
" }",
"}")
.doTest();
}
}

0 comments on commit 0b9cdef

Please sign in to comment.