diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java index 1f047413f2..cdbea612cb 100644 --- a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java +++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java @@ -25,6 +25,9 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.errorprone.util.ASTHelpers; +import com.sun.source.tree.LiteralTree; +import com.sun.source.tree.MethodInvocationTree; +import com.sun.source.tree.Tree; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Type; import com.sun.tools.javac.code.Types; @@ -45,8 +48,9 @@ * Represents an extended notion of an access path, which we track for nullness. * *

Typically, access paths are of the form x.f.g.h, where x is a variable and f, g, and h are - * field names. Here, we also allow no-argument methods to appear in the access path, so it can be - * of the form x.f().g.h() in general. + * field names. Here, we also allow no-argument methods to appear in the access path, as well as + * method calls passed only statically constant parameters, so an AP can be of the form + * x.f().g.h([int_expr|string_expr]) in general. * *

We do not allow array accesses in access paths for the moment. */ @@ -54,20 +58,20 @@ public final class AccessPath { private final Root root; - private final ImmutableList elements; + private final ImmutableList elements; /** * if present, the argument to the map get() method call that is the final element of this path */ @Nullable private final AccessPath mapGetArgAccessPath; - AccessPath(Root root, List elements) { + AccessPath(Root root, List elements) { this.root = root; this.elements = ImmutableList.copyOf(elements); this.mapGetArgAccessPath = null; } - private AccessPath(Root root, List elements, AccessPath mapGetArgAccessPath) { + private AccessPath(Root root, List elements, AccessPath mapGetArgAccessPath) { this.root = root; this.elements = ImmutableList.copyOf(elements); this.mapGetArgAccessPath = mapGetArgAccessPath; @@ -96,7 +100,7 @@ static AccessPath fromVarDecl(VariableDeclarationNode node) { */ @Nullable static AccessPath fromFieldAccess(FieldAccessNode node) { - List elements = new ArrayList<>(); + List elements = new ArrayList<>(); Root root = populateElementsRec(node, elements); return (root != null) ? new AccessPath(root, elements) : null; } @@ -115,7 +119,7 @@ static AccessPath fromMethodCall(MethodInvocationNode node, @Nullable Types type @Nullable private static AccessPath fromVanillaMethodCall(MethodInvocationNode node) { - List elements = new ArrayList<>(); + List elements = new ArrayList<>(); Root root = populateElementsRec(node, elements); return (root != null) ? new AccessPath(root, elements) : null; } @@ -127,12 +131,12 @@ private static AccessPath fromVanillaMethodCall(MethodInvocationNode node) { */ @Nullable public static AccessPath fromBaseAndElement(Node base, Element element) { - List elements = new ArrayList<>(); + List elements = new ArrayList<>(); Root root = populateElementsRec(base, elements); if (root == null) { return null; } - elements.add(element); + elements.add(new AccessPathElement(element)); return new AccessPath(root, elements); } @@ -161,7 +165,7 @@ private static AccessPath fromMapGetCall(MethodInvocationNode node) { } MethodAccessNode target = node.getTarget(); Node receiver = target.getReceiver(); - List elements = new ArrayList<>(); + List elements = new ArrayList<>(); Root root = populateElementsRec(receiver, elements); if (root == null) { return null; @@ -202,8 +206,14 @@ public static AccessPath getAccessPathForNodeWithMapGet(Node node, @Nullable Typ } } + private static boolean isBoxingMethod(Symbol.MethodSymbol methodSymbol) { + return methodSymbol.isStatic() + && methodSymbol.getSimpleName().contentEquals("valueOf") + && methodSymbol.enclClass().packge().fullname.contentEquals("java.lang"); + } + @Nullable - private static Root populateElementsRec(Node node, List elements) { + private static Root populateElementsRec(Node node, List elements) { Root result; if (node instanceof FieldAccessNode) { FieldAccessNode fieldAccess = (FieldAccessNode) node; @@ -213,17 +223,48 @@ private static Root populateElementsRec(Node node, List elements) { } else { // instance field access result = populateElementsRec(fieldAccess.getReceiver(), elements); - elements.add(fieldAccess.getElement()); + elements.add(new AccessPathElement(fieldAccess.getElement())); } } else if (node instanceof MethodInvocationNode) { MethodInvocationNode invocation = (MethodInvocationNode) node; - // only support zero-argument methods - if (invocation.getArguments().size() > 0) { - return null; - } + AccessPathElement accessPathElement; MethodAccessNode accessNode = invocation.getTarget(); + if (invocation.getArguments().size() == 0) { + accessPathElement = new AccessPathElement(accessNode.getMethod()); + } else { + List constantArgumentValues = new ArrayList<>(); + for (Node argumentNode : invocation.getArguments()) { + Tree tree = argumentNode.getTree(); + if (tree == null) { + return null; // Not an AP + } else if (tree.getKind().equals(Tree.Kind.METHOD_INVOCATION)) { + // Check for boxing call + MethodInvocationTree methodInvocationTree = (MethodInvocationTree) tree; + if (methodInvocationTree.getArguments().size() == 1 + && isBoxingMethod(ASTHelpers.getSymbol(methodInvocationTree))) { + tree = methodInvocationTree.getArguments().get(0); + } + } + switch (tree.getKind()) { + case BOOLEAN_LITERAL: + case CHAR_LITERAL: + case DOUBLE_LITERAL: + case FLOAT_LITERAL: + case INT_LITERAL: + case LONG_LITERAL: + case STRING_LITERAL: + constantArgumentValues.add(((LiteralTree) tree).getValue().toString()); + break; + case NULL_LITERAL: + // Um, probably not? Cascade to default for now. + default: + return null; // Not an AP + } + } + accessPathElement = new AccessPathElement(accessNode.getMethod(), constantArgumentValues); + } result = populateElementsRec(accessNode.getReceiver(), elements); - elements.add(accessNode.getMethod()); + elements.add(accessPathElement); } else if (node instanceof LocalVariableNode) { result = new Root(((LocalVariableNode) node).getElement()); } else if (node instanceof ThisLiteralNode) { @@ -269,7 +310,7 @@ public Root getRoot() { return root; } - public ImmutableList getElements() { + public ImmutableList getElements() { return elements; } diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathElement.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathElement.java new file mode 100644 index 0000000000..574850c776 --- /dev/null +++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathElement.java @@ -0,0 +1,67 @@ +package com.uber.nullaway.dataflow; + +import com.google.common.collect.ImmutableList; +import java.util.Arrays; +import java.util.List; +import javax.annotation.Nullable; +import javax.lang.model.element.Element; + +/** + * Represents a (non-root) element of an AccessPath. + * + *

This is just a java Element (field, method, etc) in the access-path chain (e.g. f or g() in + * x.f.g()). Plus, optionally, a list of constant arguments, allowing access path elements for + * method calls with constant values (e.g. h(3) or k("STR_KEY") in x.h(3).g().k("STR_KEY")). + */ +public final class AccessPathElement { + private final Element javaElement; + @Nullable private final ImmutableList constantArguments; + + public AccessPathElement(Element javaElement, List constantArguments) { + this.javaElement = javaElement; + this.constantArguments = ImmutableList.copyOf(constantArguments); + } + + public AccessPathElement(Element javaElement) { + this.javaElement = javaElement; + this.constantArguments = null; + } + + public Element getJavaElement() { + return this.javaElement; + } + + public ImmutableList getConstantArguments() { + return this.constantArguments; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof AccessPathElement) { + AccessPathElement otherNode = (AccessPathElement) obj; + return this.javaElement.equals(otherNode.javaElement) + && (constantArguments == null + ? otherNode.constantArguments == null + : constantArguments.equals(otherNode.constantArguments)); + } else { + return false; + } + } + + @Override + public int hashCode() { + int result = javaElement.hashCode(); + result = 31 * result + (constantArguments != null ? constantArguments.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "APElement{" + + "javaElement=" + + javaElement.toString() + + ", constantArguments=" + + Arrays.deepToString(constantArguments.toArray()) + + '}'; + } +} diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java index 40e20e902f..72d6000160 100644 --- a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java +++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessAnalysis.java @@ -112,9 +112,9 @@ private Set getNonnullReceiverFields(NullnessStore nullnessResult) { Set result = new LinkedHashSet<>(); for (AccessPath ap : nonnullAccessPaths) { if (ap.getRoot().isReceiver()) { - ImmutableList elements = ap.getElements(); + ImmutableList elements = ap.getElements(); if (elements.size() == 1) { - Element elem = elements.get(0); + Element elem = elements.get(0).getJavaElement(); if (elem.getKind().equals(ElementKind.FIELD)) { result.add(elem); } diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java b/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java index 454c940e4a..539fa218dd 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/OptionalEmptinessHandler.java @@ -120,9 +120,11 @@ public boolean includeApInfoInSavedContext(AccessPath accessPath, VisitorState s if (accessPath.getElements().size() == 1) { AccessPath.Root root = accessPath.getRoot(); - if (!root.isReceiver() && (accessPath.getElements().get(0) instanceof Symbol.MethodSymbol)) { + if (!root.isReceiver() + && (accessPath.getElements().get(0).getJavaElement() instanceof Symbol.MethodSymbol)) { final Element e = root.getVarElement(); - final Symbol.MethodSymbol g = (Symbol.MethodSymbol) accessPath.getElements().get(0); + final Symbol.MethodSymbol g = + (Symbol.MethodSymbol) accessPath.getElements().get(0).getJavaElement(); return e.getKind().equals(ElementKind.LOCAL_VARIABLE) && optionalIsGetCall(g, state.getTypes()); } diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/RxNullabilityPropagator.java b/nullaway/src/main/java/com/uber/nullaway/handlers/RxNullabilityPropagator.java index 8f4e0843fe..cb516dabb4 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/RxNullabilityPropagator.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/RxNullabilityPropagator.java @@ -49,6 +49,7 @@ import com.uber.nullaway.NullabilityUtil; import com.uber.nullaway.Nullness; import com.uber.nullaway.dataflow.AccessPath; +import com.uber.nullaway.dataflow.AccessPathElement; import com.uber.nullaway.dataflow.AccessPathNullnessAnalysis; import com.uber.nullaway.dataflow.NullnessStore; import java.util.ArrayList; @@ -393,10 +394,10 @@ public void onMatchMethodReference( assert filterNullnessStore != null; for (AccessPath ap : filterNullnessStore.getAccessPathsWithValue(Nullness.NONNULL)) { // Find the access path corresponding to the current unbound method reference after binding - ImmutableList elements = ap.getElements(); + ImmutableList elements = ap.getElements(); if (elements.size() == 1) { // We only care for single method call chains (e.g. this.foo(), not this.f.bar()) - Element element = elements.get(0); + Element element = elements.get(0).getJavaElement(); if (!element.getKind().equals(ElementKind.METHOD)) { // We are only looking for method APs continue; diff --git a/nullaway/src/test/java/com/uber/nullaway/NullAwayTest.java b/nullaway/src/test/java/com/uber/nullaway/NullAwayTest.java index dc41f53c78..361dd272c1 100644 --- a/nullaway/src/test/java/com/uber/nullaway/NullAwayTest.java +++ b/nullaway/src/test/java/com/uber/nullaway/NullAwayTest.java @@ -584,7 +584,6 @@ public void testThriftIsSet() { " }", " java.util.List l = new java.util.ArrayList<>();", " if (l.get(0).isSetId()) {", - " // BUG: Diagnostic contains: dereferenced expression l.get(0).getId()", " l.get(0).getId().hashCode();", " }", " }", @@ -1431,4 +1430,89 @@ public void testEnhancedFor() { "}") .doTest(); } + + @Test + public void testConstantsInAccessPathsNegative() { + compilationHelper + .addSourceLines( + "NullableContainer.java", + "package com.uber;", + "import javax.annotation.Nullable;", + "public interface NullableContainer {", + " @Nullable public V get(K k);", + "}") + .addSourceLines( + "Test.java", + "package com.uber;", + "import javax.annotation.Nullable;", + "public class Test {", + " public void testSingleStringCheck(NullableContainer c) {", + " if (c.get(\"KEY_STR\") != null) {", + " c.get(\"KEY_STR\").toString(); // is safe", + " }", + " }", + " public void testSingleIntCheck(NullableContainer c) {", + " if (c.get(42) != null) {", + " c.get(42).toString(); // is safe", + " }", + " }", + " public void testMultipleChecks(NullableContainer> c) {", + " if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(42) != null) {", + " c.get(\"KEY_STR\").get(42).toString(); // is safe", + " }", + " }", + "}") + .doTest(); + } + + @Test + public void testConstantsInAccessPathsPositive() { + compilationHelper + .addSourceLines( + "NullableContainer.java", + "package com.uber;", + "import javax.annotation.Nullable;", + "public interface NullableContainer {", + " @Nullable public V get(K k);", + "}") + .addSourceLines( + "Test.java", + "package com.uber;", + "import javax.annotation.Nullable;", + "public class Test {", + " public void testEnhancedFor(NullableContainer> c) {", + " if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(0) != null) {", + " // BUG: Diagnostic contains: dereferenced expression c.get(\"KEY_STR\").get(42)", + " c.get(\"KEY_STR\").get(42).toString();", + " }", + " }", + "}") + .doTest(); + } + + @Test + public void testVariablesInAccessPathsPositive() { + compilationHelper + .addSourceLines( + "NullableContainer.java", + "package com.uber;", + "import javax.annotation.Nullable;", + "public interface NullableContainer {", + " @Nullable public V get(K k);", + "}") + .addSourceLines( + "Test.java", + "package com.uber;", + "import javax.annotation.Nullable;", + "public class Test {", + " private Integer intKey = 42;", // No guarantee it's a constant + " public void testEnhancedFor(NullableContainer> c) {", + " if (c.get(\"KEY_STR\") != null && c.get(\"KEY_STR\").get(this.intKey) != null) {", + " // BUG: Diagnostic contains: dereferenced expression c.get(\"KEY_STR\").get", + " c.get(\"KEY_STR\").get(this.intKey).toString();", + " }", + " }", + "}") + .doTest(); + } }