Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for methods taking literal constant args in Access Paths. #285

Merged
merged 1 commit into from Mar 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 59 additions & 18 deletions nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java
Expand Up @@ -25,6 +25,9 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.LiteralTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.Tree;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Types;
Expand All @@ -45,29 +48,30 @@
* Represents an extended notion of an access path, which we track for nullness.
*
* <p>Typically, access paths are of the form x.f.g.h, where x is a variable and f, g, and h are
* field names. Here, we also allow no-argument methods to appear in the access path, so it can be
* of the form x.f().g.h() in general.
* field names. Here, we also allow no-argument methods to appear in the access path, as well as
* method calls passed only statically constant parameters, so an AP can be of the form
* x.f().g.h([int_expr|string_expr]) in general.
*
* <p>We do not allow array accesses in access paths for the moment.
*/
public final class AccessPath {

private final Root root;

private final ImmutableList<Element> elements;
private final ImmutableList<AccessPathElement> elements;

/**
* if present, the argument to the map get() method call that is the final element of this path
*/
@Nullable private final AccessPath mapGetArgAccessPath;

AccessPath(Root root, List<Element> elements) {
AccessPath(Root root, List<AccessPathElement> elements) {
this.root = root;
this.elements = ImmutableList.copyOf(elements);
this.mapGetArgAccessPath = null;
}

private AccessPath(Root root, List<Element> elements, AccessPath mapGetArgAccessPath) {
private AccessPath(Root root, List<AccessPathElement> elements, AccessPath mapGetArgAccessPath) {
this.root = root;
this.elements = ImmutableList.copyOf(elements);
this.mapGetArgAccessPath = mapGetArgAccessPath;
Expand Down Expand Up @@ -96,7 +100,7 @@ static AccessPath fromVarDecl(VariableDeclarationNode node) {
*/
@Nullable
static AccessPath fromFieldAccess(FieldAccessNode node) {
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(node, elements);
return (root != null) ? new AccessPath(root, elements) : null;
}
Expand All @@ -115,7 +119,7 @@ static AccessPath fromMethodCall(MethodInvocationNode node, @Nullable Types type

@Nullable
private static AccessPath fromVanillaMethodCall(MethodInvocationNode node) {
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(node, elements);
return (root != null) ? new AccessPath(root, elements) : null;
}
Expand All @@ -127,12 +131,12 @@ private static AccessPath fromVanillaMethodCall(MethodInvocationNode node) {
*/
@Nullable
public static AccessPath fromBaseAndElement(Node base, Element element) {
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(base, elements);
if (root == null) {
return null;
}
elements.add(element);
elements.add(new AccessPathElement(element));
return new AccessPath(root, elements);
}

Expand Down Expand Up @@ -161,7 +165,7 @@ private static AccessPath fromMapGetCall(MethodInvocationNode node) {
}
MethodAccessNode target = node.getTarget();
Node receiver = target.getReceiver();
List<Element> elements = new ArrayList<>();
List<AccessPathElement> elements = new ArrayList<>();
Root root = populateElementsRec(receiver, elements);
if (root == null) {
return null;
Expand Down Expand Up @@ -202,8 +206,14 @@ public static AccessPath getAccessPathForNodeWithMapGet(Node node, @Nullable Typ
}
}

private static boolean isBoxingMethod(Symbol.MethodSymbol methodSymbol) {
return methodSymbol.isStatic()
&& methodSymbol.getSimpleName().contentEquals("valueOf")
&& methodSymbol.enclClass().packge().fullname.contentEquals("java.lang");
}

@Nullable
private static Root populateElementsRec(Node node, List<Element> elements) {
private static Root populateElementsRec(Node node, List<AccessPathElement> elements) {
Root result;
if (node instanceof FieldAccessNode) {
FieldAccessNode fieldAccess = (FieldAccessNode) node;
Expand All @@ -213,17 +223,48 @@ private static Root populateElementsRec(Node node, List<Element> elements) {
} else {
// instance field access
result = populateElementsRec(fieldAccess.getReceiver(), elements);
elements.add(fieldAccess.getElement());
elements.add(new AccessPathElement(fieldAccess.getElement()));
}
} else if (node instanceof MethodInvocationNode) {
MethodInvocationNode invocation = (MethodInvocationNode) node;
// only support zero-argument methods
if (invocation.getArguments().size() > 0) {
return null;
}
AccessPathElement accessPathElement;
MethodAccessNode accessNode = invocation.getTarget();
if (invocation.getArguments().size() == 0) {
accessPathElement = new AccessPathElement(accessNode.getMethod());
} else {
List<String> constantArgumentValues = new ArrayList<>();
for (Node argumentNode : invocation.getArguments()) {
Tree tree = argumentNode.getTree();
if (tree == null) {
return null; // Not an AP
} else if (tree.getKind().equals(Tree.Kind.METHOD_INVOCATION)) {
// Check for boxing call
MethodInvocationTree methodInvocationTree = (MethodInvocationTree) tree;
if (methodInvocationTree.getArguments().size() == 1
&& isBoxingMethod(ASTHelpers.getSymbol(methodInvocationTree))) {
tree = methodInvocationTree.getArguments().get(0);
}
}
switch (tree.getKind()) {
case BOOLEAN_LITERAL:
case CHAR_LITERAL:
case DOUBLE_LITERAL:
case FLOAT_LITERAL:
case INT_LITERAL:
case LONG_LITERAL:
case STRING_LITERAL:
constantArgumentValues.add(((LiteralTree) tree).getValue().toString());
break;
case NULL_LITERAL:
// Um, probably not? Cascade to default for now.
default:
return null; // Not an AP
}
}
accessPathElement = new AccessPathElement(accessNode.getMethod(), constantArgumentValues);
}
result = populateElementsRec(accessNode.getReceiver(), elements);
elements.add(accessNode.getMethod());
elements.add(accessPathElement);
} else if (node instanceof LocalVariableNode) {
result = new Root(((LocalVariableNode) node).getElement());
} else if (node instanceof ThisLiteralNode) {
Expand Down Expand Up @@ -269,7 +310,7 @@ public Root getRoot() {
return root;
}

public ImmutableList<Element> getElements() {
public ImmutableList<AccessPathElement> getElements() {
return elements;
}

Expand Down
@@ -0,0 +1,67 @@
package com.uber.nullaway.dataflow;

import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nullable;
import javax.lang.model.element.Element;

/**
* Represents a (non-root) element of an AccessPath.
*
* <p>This is just a java Element (field, method, etc) in the access-path chain (e.g. f or g() in
* x.f.g()). Plus, optionally, a list of constant arguments, allowing access path elements for
* method calls with constant values (e.g. h(3) or k("STR_KEY") in x.h(3).g().k("STR_KEY")).
*/
public final class AccessPathElement {
private final Element javaElement;
@Nullable private final ImmutableList<String> constantArguments;

public AccessPathElement(Element javaElement, List<String> constantArguments) {
this.javaElement = javaElement;
this.constantArguments = ImmutableList.copyOf(constantArguments);
}

public AccessPathElement(Element javaElement) {
this.javaElement = javaElement;
this.constantArguments = null;
}

public Element getJavaElement() {
return this.javaElement;
}

public ImmutableList<String> getConstantArguments() {
return this.constantArguments;
}

@Override
public boolean equals(Object obj) {
if (obj instanceof AccessPathElement) {
AccessPathElement otherNode = (AccessPathElement) obj;
return this.javaElement.equals(otherNode.javaElement)
&& (constantArguments == null
? otherNode.constantArguments == null
: constantArguments.equals(otherNode.constantArguments));
} else {
return false;
}
}

@Override
public int hashCode() {
int result = javaElement.hashCode();
result = 31 * result + (constantArguments != null ? constantArguments.hashCode() : 0);
return result;
}

@Override
public String toString() {
return "APElement{"
+ "javaElement="
+ javaElement.toString()
+ ", constantArguments="
+ Arrays.deepToString(constantArguments.toArray())
+ '}';
}
}
Expand Up @@ -112,9 +112,9 @@ private Set<Element> getNonnullReceiverFields(NullnessStore nullnessResult) {
Set<Element> result = new LinkedHashSet<>();
for (AccessPath ap : nonnullAccessPaths) {
if (ap.getRoot().isReceiver()) {
ImmutableList<Element> elements = ap.getElements();
ImmutableList<AccessPathElement> elements = ap.getElements();
if (elements.size() == 1) {
Element elem = elements.get(0);
Element elem = elements.get(0).getJavaElement();
if (elem.getKind().equals(ElementKind.FIELD)) {
result.add(elem);
}
Expand Down
Expand Up @@ -120,9 +120,11 @@ public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState s

if (accessPath.getElements().size() == 1) {
AccessPath.Root root = accessPath.getRoot();
if (!root.isReceiver() && (accessPath.getElements().get(0) instanceof Symbol.MethodSymbol)) {
if (!root.isReceiver()
&& (accessPath.getElements().get(0).getJavaElement() instanceof Symbol.MethodSymbol)) {
final Element e = root.getVarElement();
final Symbol.MethodSymbol g = (Symbol.MethodSymbol) accessPath.getElements().get(0);
final Symbol.MethodSymbol g =
(Symbol.MethodSymbol) accessPath.getElements().get(0).getJavaElement();
return e.getKind().equals(ElementKind.LOCAL_VARIABLE)
&& optionalIsGetCall(g, state.getTypes());
}
Expand Down
Expand Up @@ -49,6 +49,7 @@
import com.uber.nullaway.NullabilityUtil;
import com.uber.nullaway.Nullness;
import com.uber.nullaway.dataflow.AccessPath;
import com.uber.nullaway.dataflow.AccessPathElement;
import com.uber.nullaway.dataflow.AccessPathNullnessAnalysis;
import com.uber.nullaway.dataflow.NullnessStore;
import java.util.ArrayList;
Expand Down Expand Up @@ -393,10 +394,10 @@ public void onMatchMethodReference(
assert filterNullnessStore != null;
for (AccessPath ap : filterNullnessStore.getAccessPathsWithValue(Nullness.NONNULL)) {
// Find the access path corresponding to the current unbound method reference after binding
ImmutableList<Element> elements = ap.getElements();
ImmutableList<AccessPathElement> elements = ap.getElements();
if (elements.size() == 1) {
// We only care for single method call chains (e.g. this.foo(), not this.f.bar())
Element element = elements.get(0);
Element element = elements.get(0).getJavaElement();
if (!element.getKind().equals(ElementKind.METHOD)) {
// We are only looking for method APs
continue;
Expand Down
86 changes: 85 additions & 1 deletion nullaway/src/test/java/com/uber/nullaway/NullAwayTest.java
Expand Up @@ -584,7 +584,6 @@ public void testThriftIsSet() {
" }",
" java.util.List<Generated> l = new java.util.ArrayList<>();",
" if (l.get(0).isSetId()) {",
" // BUG: Diagnostic contains: dereferenced expression l.get(0).getId()",
" l.get(0).getId().hashCode();",
" }",
" }",
Expand Down Expand Up @@ -1431,4 +1430,89 @@ public void testEnhancedFor() {
"}")
.doTest();
}

@Test
public void testConstantsInAccessPathsNegative() {
compilationHelper
.addSourceLines(
"NullableContainer.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public interface NullableContainer<K, V> {",
" @Nullable public V get(K k);",
"}")
.addSourceLines(
"Test.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public class Test {",
" public void testSingleStringCheck(NullableContainer<String, Object> c) {",
" if (c.get(\"KEY_STR\") != null) {",
" c.get(\"KEY_STR\").toString(); // is safe",
" }",
" }",
" public void testSingleIntCheck(NullableContainer<Integer, Object> c) {",
" if (c.get(42) != null) {",
" c.get(42).toString(); // is safe",
" }",
" }",
" public void testMultipleChecks(NullableContainer<String, NullableContainer<Integer, Object>> c) {",
" if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(42) != null) {",
" c.get(\"KEY_STR\").get(42).toString(); // is safe",
" }",
" }",
"}")
.doTest();
}

@Test
public void testConstantsInAccessPathsPositive() {
compilationHelper
.addSourceLines(
"NullableContainer.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public interface NullableContainer<K, V> {",
" @Nullable public V get(K k);",
"}")
.addSourceLines(
"Test.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public class Test {",
" public void testEnhancedFor(NullableContainer<String, NullableContainer<Integer, Object>> c) {",
" if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(0) != null) {",
" // BUG: Diagnostic contains: dereferenced expression c.get(\"KEY_STR\").get(42)",
" c.get(\"KEY_STR\").get(42).toString();",
" }",
" }",
"}")
.doTest();
}

@Test
public void testVariablesInAccessPathsPositive() {
compilationHelper
.addSourceLines(
"NullableContainer.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public interface NullableContainer<K, V> {",
" @Nullable public V get(K k);",
"}")
.addSourceLines(
"Test.java",
"package com.uber;",
"import javax.annotation.Nullable;",
"public class Test {",
" private Integer intKey = 42;", // No guarantee it's a constant
" public void testEnhancedFor(NullableContainer<String, NullableContainer<Integer, Object>> c) {",
" if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(this.intKey) != null) {",
" // BUG: Diagnostic contains: dereferenced expression c.get(\"KEY_STR\").get",
" c.get(\"KEY_STR\").get(this.intKey).toString();",
" }",
" }",
"}")
.doTest();
}
}