Skip to content

Commit

Permalink
Remove usages of ClassName.bestGuess(). There are still remaining usa…
Browse files Browse the repository at this point in the history
…ges for roots, but this will help with modules inside packages that don't follow normal capitalization conventions.

Issue #3329.

RELNOTES=Address part of #3329 where modules in a package with non-standard capitalization could cause an error.
PiperOrigin-RevId: 441024848
  • Loading branch information
Chang-Eric authored and Dagger Team committed Apr 12, 2022
1 parent c9b5df5 commit af76259
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 138 deletions.
Expand Up @@ -40,7 +40,8 @@ import org.objectweb.asm.Type
import org.slf4j.Logger

/** Aggregates Hilt dependencies. */
internal class Aggregator private constructor(
internal class Aggregator
private constructor(
private val logger: Logger,
private val asmApiVersion: Int,
) {
Expand Down Expand Up @@ -137,8 +138,8 @@ internal class Aggregator private constructor(
aggregatedRoots.add(
AggregatedRootIr(
fqName = annotatedClassName,
root = rootClass.toClassName(),
originatingRoot = originatingRootClass.toClassName(),
root = ClassName.bestGuess(rootClass),
originatingRoot = ClassName.bestGuess(originatingRootClass),
rootAnnotation = rootAnnotationClassName.toClassName()
)
)
Expand All @@ -159,10 +160,7 @@ internal class Aggregator private constructor(

override fun visitEnd() {
processedRoots.add(
ProcessedRootSentinelIr(
fqName = annotatedClassName,
roots = rootClasses.map { it.toClassName() }
)
ProcessedRootSentinelIr(fqName = annotatedClassName, roots = rootClasses)
)
super.visitEnd()
}
Expand All @@ -181,10 +179,7 @@ internal class Aggregator private constructor(

override fun visitEnd() {
defineComponentDeps.add(
DefineComponentClassesIr(
fqName = annotatedClassName,
component = componentClass.toClassName()
)
DefineComponentClassesIr(fqName = annotatedClassName, component = componentClass())
)
super.visitEnd()
}
Expand Down Expand Up @@ -239,14 +234,10 @@ internal class Aggregator private constructor(

override fun visitArray(name: String): AnnotationVisitor? {
return when (name) {
"components" ->
visitValue { value -> componentClasses.add(value as String) }
"replaces" ->
visitValue { value -> replacesClasses.add(value as String) }
"modules" ->
visitValue { value -> moduleClass = value as String }
"entryPoints" ->
visitValue { value -> entryPoint = value as String }
"components" -> visitValue { value -> componentClasses.add(value as String) }
"replaces" -> visitValue { value -> replacesClasses.add(value as String) }
"modules" -> visitValue { value -> moduleClass = value as String }
"entryPoints" -> visitValue { value -> entryPoint = value as String }
"componentEntryPoints" ->
visitValue { value -> componentEntryPoint = value as String }
else -> super.visitArray(name)
Expand All @@ -257,12 +248,12 @@ internal class Aggregator private constructor(
aggregatedDeps.add(
AggregatedDepsIr(
fqName = annotatedClassName,
components = componentClasses.map { it.toClassName() },
test = testClass?.toClassName(),
replaces = replacesClasses.map { it.toClassName() },
module = moduleClass?.toClassName(),
entryPoint = entryPoint?.toClassName(),
componentEntryPoint = componentEntryPoint?.toClassName()
components = componentClasses,
test = testClass,
replaces = replacesClasses,
module = moduleClass,
entryPoint = entryPoint,
componentEntryPoint = componentEntryPoint
)
)
super.visitEnd()
Expand Down Expand Up @@ -315,8 +306,8 @@ internal class Aggregator private constructor(
uninstallModulesDeps.add(
AggregatedUninstallModulesIr(
fqName = annotatedClassName,
test = testClass.toClassName(),
uninstallModules = uninstallModulesClasses.map { it.toClassName() }
test = testClass,
uninstallModules = uninstallModulesClasses
)
)
super.visitEnd()
Expand All @@ -338,7 +329,7 @@ internal class Aggregator private constructor(
earlyEntryPointDeps.add(
AggregatedEarlyEntryPointIr(
fqName = annotatedClassName,
earlyEntryPoint = earlyEntryPointClass.toClassName()
earlyEntryPoint = earlyEntryPointClass
)
)
super.visitEnd()
Expand Down Expand Up @@ -372,48 +363,43 @@ internal class Aggregator private constructor(

private fun visitFile(file: File) {
when {
file.isJarFile() -> ZipInputStream(file.inputStream()).forEachZipEntry { inputStream, entry ->
if (entry.isClassFile()) {
visitClass(inputStream)
file.isJarFile() ->
ZipInputStream(file.inputStream()).forEachZipEntry { inputStream, entry ->
if (entry.isClassFile()) {
visitClass(inputStream)
}
}
}
file.isClassFile() -> file.inputStream().use { visitClass(it) }
else -> logger.debug("Don't know how to process file: $file")
}
}

private fun visitClass(classFileInputStream: InputStream) {
ClassReader(classFileInputStream).accept(
classVisitor,
ClassReader.SKIP_CODE and ClassReader.SKIP_DEBUG and ClassReader.SKIP_FRAMES
)
ClassReader(classFileInputStream)
.accept(
classVisitor,
ClassReader.SKIP_CODE and ClassReader.SKIP_DEBUG and ClassReader.SKIP_FRAMES
)
}

companion object {
fun from(
logger: Logger,
asmApiVersion: Int,
input: Iterable<File>
) = Aggregator(logger, asmApiVersion).apply { process(input) }
fun from(logger: Logger, asmApiVersion: Int, input: Iterable<File>) =
Aggregator(logger, asmApiVersion).apply { process(input) }

// Converts this Type to a ClassName, used instead of ClassName.bestGuess() because ASM class
// names are based off descriptors and uses 'reflection' naming, i.e. inner classes are split
// by '$' instead of '.'
fun Type.toClassName(): ClassName {
val binaryName = this.className
val packageNameEndIndex = binaryName.lastIndexOf('.')
val packageName = if (packageNameEndIndex != -1) {
binaryName.substring(0, packageNameEndIndex)
} else {
""
}
val packageName =
if (packageNameEndIndex != -1) {
binaryName.substring(0, packageNameEndIndex)
} else {
""
}
val shortNames = binaryName.substring(packageNameEndIndex + 1).split('$')
return ClassName.get(packageName, shortNames.first(), *shortNames.drop(1).toTypedArray())
}

// Converts this String representing the canonical name of a class to a ClassName.
fun String.toClassName(): ClassName {
return ClassName.bestGuess(this)
}
}
}
Expand Up @@ -88,21 +88,24 @@ public static AggregatedDepsIr toIr(AggregatedDepsMetadata metadata) {
ClassName.get(metadata.aggregatingElement()),
metadata.componentElements().stream()
.map(ClassName::get)
.map(ClassName::canonicalName)
.collect(Collectors.toList()),
metadata.testElement()
.map(ClassName::get)
.map(ClassName::canonicalName)
.orElse(null),
metadata.replacedDependencies().stream()
.map(ClassName::get)
.map(ClassName::canonicalName)
.collect(Collectors.toList()),
metadata.dependencyType() == DependencyType.MODULE
? ClassName.get(metadata.dependency())
? ClassName.get(metadata.dependency()).canonicalName()
: null,
metadata.dependencyType() == DependencyType.ENTRY_POINT
? ClassName.get(metadata.dependency())
? ClassName.get(metadata.dependency()).canonicalName()
: null,
metadata.dependencyType() == DependencyType.COMPONENT_ENTRY_POINT
? ClassName.get(metadata.dependency())
? ClassName.get(metadata.dependency()).canonicalName()
: null);
}

Expand Down
Expand Up @@ -112,6 +112,6 @@ private static DefineComponentClassesMetadata create(TypeElement element, Elemen
public static DefineComponentClassesIr toIr(DefineComponentClassesMetadata metadata) {
return new DefineComponentClassesIr(
ClassName.get(metadata.aggregatingElement()),
ClassName.get(metadata.element()));
ClassName.get(metadata.element()).canonicalName());
}
}
Expand Up @@ -66,7 +66,7 @@ public static ImmutableSet<AggregatedEarlyEntryPointMetadata> from(
public static AggregatedEarlyEntryPointIr toIr(AggregatedEarlyEntryPointMetadata metadata) {
return new AggregatedEarlyEntryPointIr(
ClassName.get(metadata.aggregatingElement()),
ClassName.get(metadata.earlyEntryPoint()));
ClassName.get(metadata.earlyEntryPoint()).canonicalName());
}

private static AggregatedEarlyEntryPointMetadata create(TypeElement element, Elements elements) {
Expand Down
Expand Up @@ -59,7 +59,10 @@ static ImmutableSet<ProcessedRootSentinelMetadata> from(Elements elements) {
static ProcessedRootSentinelIr toIr(ProcessedRootSentinelMetadata metadata) {
return new ProcessedRootSentinelIr(
ClassName.get(metadata.aggregatingElement()),
metadata.rootElements().stream().map(ClassName::get).collect(Collectors.toList())
metadata.rootElements().stream()
.map(ClassName::get)
.map(ClassName::canonicalName)
.collect(Collectors.toList())
);
}

Expand Down
Expand Up @@ -29,8 +29,8 @@ object AggregatedRootIrValidator {
): Set<AggregatedRootIr> {
val processedRootNames = processedRoots.flatMap { it.roots }.toSet()
val rootsToProcess =
aggregatedRoots.filterNot { processedRootNames.contains(it.root) }.sortedBy {
it.root.toString()
aggregatedRoots.filterNot { processedRootNames.contains(it.root.canonicalName()) }.sortedBy {
it.root.canonicalName()
}
val testRootsToProcess = rootsToProcess.filter { it.isTestRoot }
val appRootsToProcess = rootsToProcess - testRootsToProcess
Expand All @@ -53,7 +53,9 @@ object AggregatedRootIrValidator {
// Perform validation across roots previous compilation units.
if (!isCrossCompilationRootValidationDisabled) {
val alreadyProcessedTestRoots =
aggregatedRoots.filter { it.isTestRoot && processedRootNames.contains(it.root) }
aggregatedRoots.filter {
it.isTestRoot && processedRootNames.contains(it.root.canonicalName())
}
// TODO(b/185742783): Add an explanation or link to docs to explain why we're forbidding this.
if (alreadyProcessedTestRoots.isNotEmpty() && rootsToProcess.isNotEmpty()) {
throw InvalidRootsException(
Expand All @@ -67,7 +69,7 @@ object AggregatedRootIrValidator {

val alreadyProcessedAppRoots =
aggregatedRoots.filter {
!it.isTestRoot && processedRootNames.contains(it.root)
!it.isTestRoot && processedRootNames.contains(it.root.canonicalName())
}
if (alreadyProcessedAppRoots.isNotEmpty() && appRootsToProcess.isNotEmpty()) {
throw InvalidRootsException(
Expand Down

0 comments on commit af76259

Please sign in to comment.