Skip to content

Commit

Permalink
Merge pull request #149 from yandex/wp/codegen-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeffset committed May 8, 2024
2 parents 225482c + e06dad1 commit a554fc6
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.yandex.yatagan.codegen.impl

import com.squareup.javapoet.ClassName
import com.yandex.yatagan.codegen.poetry.CodeBuilder
import com.yandex.yatagan.codegen.poetry.TypeSpecBuilder
import com.yandex.yatagan.codegen.poetry.buildExpression
import com.yandex.yatagan.core.graph.BindingGraph
Expand Down Expand Up @@ -54,62 +55,47 @@ internal class SlotSwitchingGenerator @Inject constructor(

override fun generate(builder: TypeSpecBuilder) {
if (!isUsed) return

builder.method(FactoryMethodName) {
modifiers(/*package-private*/)
returnType(ClassName.OBJECT)
parameter(ClassName.INT, "slot")
if (maxSlotsPerSwitch > 1 && bindings.size > maxSlotsPerSwitch) {
// Strategy with two-level-nested switches

val outerSlotsCount = bindings.size / maxSlotsPerSwitch
controlFlow("switch(slot / $maxSlotsPerSwitch)") {
for(outerSlot in 0 .. outerSlotsCount) {
val nestedFactoryName = "$FactoryMethodName\$$outerSlot"
builder.method(nestedFactoryName) {
val chunks = bindings.chunked(if (maxSlotsPerSwitch > 1) maxSlotsPerSwitch else Int.MAX_VALUE)
when(val singleChunk = chunks.singleOrNull()) {
null -> controlFlow("switch(slot / $maxSlotsPerSwitch)") {
// Strategy with two-level-nested switches
chunks.forEachIndexed { chunkIndex, chunk ->
val nestedFactoryFunctionName = "$FactoryMethodName\$$chunkIndex"
builder.method(nestedFactoryFunctionName) {
modifiers(Modifier.PRIVATE)
returnType(ClassName.OBJECT)
parameter(ClassName.INT, "slot")
controlFlow("switch(slot)") {
val startIndex = outerSlot * maxSlotsPerSwitch
var nestedSlot = 0
while((startIndex + nestedSlot) < bindings.size && nestedSlot < maxSlotsPerSwitch) {
val binding = bindings[startIndex + nestedSlot]
+buildExpression {
+"case $nestedSlot: return "
binding.generateAccess(
builder = this,
inside = thisGraph,
isInsideInnerClass = false,
)
}
++nestedSlot
}
+"default: throw new %T()".formatCode(Names.AssertionError)
}
}
+buildExpression {
+"case $outerSlot: return %N(slot %L 100)".formatCode(nestedFactoryName, "%")
generateSwitchForChunk(chunk)
}
+"case $chunkIndex: return %N(slot %L $maxSlotsPerSwitch)"
.formatCode(nestedFactoryFunctionName, "%")
}
+"default: throw new %T()".formatCode(Names.AssertionError)
}
} else {
// Single switch statement
else -> generateSwitchForChunk(singleChunk)
}
}
}

controlFlow("switch(slot)") {
bindings.forEachIndexed { slot, binding ->
+buildExpression {
+"case $slot: return "
binding.generateAccess(
builder = this,
inside = thisGraph,
isInsideInnerClass = false,
)
}
}
+"default: throw new %T()".formatCode(Names.AssertionError)
private fun CodeBuilder.generateSwitchForChunk(chunk: List<Binding>) {
controlFlow("switch(slot)") {
chunk.forEachIndexed { slot, binding ->
+buildExpression {
+"case $slot: return "
binding.generateAccess(
builder = this,
inside = thisGraph,
isInsideInnerClass = false,
)
}
}
+"default: throw new %T()".formatCode(Names.AssertionError)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.yandex.yatagan.testing.tests

import androidx.room.compiler.processing.util.DiagnosticMessage
import androidx.room.compiler.processing.util.compiler.TestCompilationArguments
import androidx.room.compiler.processing.util.compiler.compile
import com.yandex.yatagan.generated.CompiledApiClasspath
Expand Down Expand Up @@ -93,7 +94,7 @@ abstract class CompileTestDriverBase private constructor(
return TestCompilationResult(
workingDir = workingDir,
runtimeClasspath = compilation.classpath + result.outputClasspath,
messageLog = result.diagnostics.values.flatten().joinToString(separator = "\n") { it.msg },
messageLog = result.diagnostics.values.flatten().joinToString(separator = "\n", transform = ::asString),
success = result.success,
generatedFiles = result.generatedSources,
)
Expand Down Expand Up @@ -152,6 +153,24 @@ abstract class CompileTestDriverBase private constructor(
?.bufferedReader()?.readText()?.ensureLineEndings() ?: ""
Assert.assertEquals(goldenOutput, strippedLog)

generatedFilesSubDir().takeIf { checkGoldenOutput }?.let {
val goldenFiles = GoldenSourceRegex.findAll(
javaClass.getResourceAsStream("/$goldenCodeResourcePath")
?.bufferedReader()?.readText()?.ensureLineEndings() ?: ""
).associateByTo(
destination = mutableMapOf(),
keySelector = { it.groupValues[1] },
valueTransform = { it.groupValues[2].trim() },
)

for (generatedFile in generatedFiles) {
val filePath = generatedFile.relativePath.replace(File.separatorChar, '/')
val goldenContents = goldenFiles.remove(filePath) ?: "<unexpected file>"
Assert.assertEquals("Generated file '${generatedFile.relativePath}' doesn't match the golden",
goldenContents, generatedFile.contents.trim())
}
}

if (success) {
// find runtime test
val classLoader = makeClassLoader(runtimeClasspath)
Expand All @@ -167,24 +186,6 @@ abstract class CompileTestDriverBase private constructor(
System.err.println(messageLog)
Assert.assertTrue("Compilation failed, yet expected output is blank", goldenOutput.isNotBlank())
}

generatedFilesSubDir().takeIf { checkGoldenOutput }?.let {
val goldenFiles = GoldenSourceRegex.findAll(
javaClass.getResourceAsStream("/$goldenCodeResourcePath")
?.bufferedReader()?.readText()?.ensureLineEndings() ?: ""
).associateByTo(
destination = mutableMapOf(),
keySelector = { it.groupValues[1] },
valueTransform = { it.groupValues[2].trim() },
)

for (generatedFile in generatedFiles) {
val filePath = generatedFile.relativePath.replace(File.separatorChar, '/')
val goldenContents = goldenFiles.remove(filePath) ?: "<unexpected file>"
Assert.assertEquals("Generated file '${generatedFile.relativePath}' doesn't match the golden",
goldenContents, generatedFile.contents.trim())
}
}
} finally {
generatedFilesSubDir()?.let { generatedFilesSubDir ->
for (generatedFile in generatedFiles) {
Expand Down Expand Up @@ -271,5 +272,14 @@ abstract class CompileTestDriverBase private constructor(

val isInUpdateGoldenMode: Boolean
get() = goldenSourceDirForUpdate != null

private fun asString(message: DiagnosticMessage): String = buildString {
append('[').append(message.kind.name.lowercase()).append(']')
append(' ')
message.location?.let {
append(it.source?.relativePath ?: "<unknown-source>").append(':').append(it.line).append(' ')
}
append(message.msg)
}
}
}

0 comments on commit a554fc6

Please sign in to comment.