Skip to content

Commit

Permalink
Support generics and inner classes in @optics (#2776)
Browse files Browse the repository at this point in the history
Co-authored-by: Imran Malic Settuba <46971368+i-walker@users.noreply.github.com>
  • Loading branch information
serras and i-walker committed Aug 3, 2022
1 parent 7bef773 commit 88f7b06
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 61 deletions.
Expand Up @@ -38,12 +38,6 @@ class OpticsProcessor(private val codegen: CodeGenerator, private val logger: KS
return
}

// check that it does not have type arguments
if (klass.typeParameters.isNotEmpty()) {
logger.error(klass.qualifiedNameOrSimpleName.typeParametersErrorMessage, klass)
return
}

// check that the companion object exists
if (klass.companionObject == null) {
logger.error(klass.qualifiedNameOrSimpleName.noCompanion, klass)
Expand Down
Expand Up @@ -2,25 +2,34 @@ package arrow.optics.plugin.internals

import arrow.optics.plugin.companionObject
import com.google.devtools.ksp.getVisibility
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSName
import com.google.devtools.ksp.symbol.Visibility
import com.google.devtools.ksp.symbol.*
import java.util.Locale

data class ADT(val pckg: KSName, val declaration: KSClassDeclaration, val targets: List<Target>) {
val sourceClassName = declaration.qualifiedNameOrSimpleName
val sourceName = declaration.simpleName.asString().replaceFirstChar { it.lowercase(Locale.getDefault()) }
val simpleName = declaration.simpleName.asString()
val simpleName = declaration.nameWithParentClass
val packageName = pckg.asString()
val visibilityModifierName = when (declaration.companionObject?.getVisibility()) {
Visibility.INTERNAL -> "internal"
else -> "public"
}
val typeParameters: List<String> = declaration.typeParameters.map { it.simpleName.asString() }
val angledTypeParameters: String = when {
typeParameters.isEmpty() -> ""
else -> "<${typeParameters.joinToString(separator = ",")}>"
}

operator fun Snippet.plus(snippet: Snippet): Snippet =
copy(imports = imports + snippet.imports, content = "$content\n${snippet.content}")
}

val KSClassDeclaration.nameWithParentClass: String
get() = when (val parent = parentDeclaration) {
is KSClassDeclaration -> parent.nameWithParentClass + "." + simpleName.asString()
else -> simpleName.asString()
}

enum class OpticsTarget {
ISO,
LENS,
Expand Down Expand Up @@ -61,27 +70,46 @@ typealias NullableFocus = Focus.Nullable
sealed class Focus {

companion object {
operator fun invoke(fullName: String, paramName: String): Focus =
operator fun invoke(fullName: String, paramName: String, refinedType: KSType? = null): Focus =
when {
fullName.endsWith("?") -> Nullable(fullName, paramName)
fullName.startsWith("`arrow`.`core`.`Option`") -> Option(fullName, paramName)
else -> NonNull(fullName, paramName)
fullName.endsWith("?") -> Nullable(fullName, paramName, refinedType)
fullName.startsWith("`arrow`.`core`.`Option`") -> Option(fullName, paramName, refinedType)
else -> NonNull(fullName, paramName, refinedType)
}
}

abstract val className: String
abstract val paramName: String

data class Nullable(override val className: String, override val paramName: String) : Focus() {
// only used for type-refining prisms
abstract val refinedType: KSType?

val refinedArguments: List<String>
get() = refinedType?.arguments?.filter {
it.type?.resolve()?.declaration is KSTypeParameter
}?.map { it.qualifiedString() }.orEmpty()

data class Nullable(
override val className: String,
override val paramName: String,
override val refinedType: KSType?
) : Focus() {
val nonNullClassName = className.dropLast(1)
}

data class Option(override val className: String, override val paramName: String) : Focus() {
data class Option(
override val className: String,
override val paramName: String,
override val refinedType: KSType?
) : Focus() {
val nestedClassName =
Regex("`arrow`.`core`.`Option`<(.*)>$").matchEntire(className)!!.groupValues[1]
}

data class NonNull(override val className: String, override val paramName: String) : Focus()
data class NonNull(
override val className: String,
override val paramName: String,
override val refinedType: KSType?
) : Focus()
}

const val Lens = "arrow.optics.Lens"
Expand Down
Expand Up @@ -22,7 +22,8 @@ fun generatePrismDsl(ele: ADT, isoOptic: SealedClassDsl): Snippet =
)

private fun processLensSyntax(ele: ADT, foci: List<Focus>): String =
foci.joinToString(separator = "\n") { focus ->
if (ele.typeParameters.isEmpty()) {
foci.joinToString(separator = "\n") { focus ->
"""
|${ele.visibilityModifierName} inline val <S> $Iso<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Lens<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
|${ele.visibilityModifierName} inline val <S> $Lens<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Lens<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
Expand All @@ -34,17 +35,36 @@ private fun processLensSyntax(ele: ADT, foci: List<Focus>): String =
|${ele.visibilityModifierName} inline val <S> $Fold<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Fold<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
|${ele.visibilityModifierName} inline val <S> $Every<S, ${ele.sourceClassName}>.${focus.lensParamName()}: $Every<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.lensParamName()}
|""".trimMargin()
}
} else {
val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}"
val joinedTypeParams = ele.typeParameters.joinToString(separator=",")
foci.joinToString(separator = "\n") { focus ->
"""
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Iso<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Lens<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Lens<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Lens<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Optional<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Prism<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Getter<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Getter<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Setter<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Setter<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Traversal<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Traversal<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Fold<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Fold<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Every<S, $sourceClassNameWithParams>.${focus.lensParamName()}(): $Every<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.lensParamName()}()
|""".trimMargin()
}
}

private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl) =
optic.foci.filterNot { it is NonNullFocus }.joinToString(separator = "\n") { focus ->
private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl): String {
val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}"
val joinedTypeParams = ele.typeParameters.joinToString(separator=",")
return optic.foci.filterNot { it is NonNullFocus }.joinToString(separator = "\n") { focus ->
val targetClassName =
when (focus) {
is NullableFocus -> focus.nonNullClassName
is OptionFocus -> focus.nestedClassName
is NonNullFocus -> ""
is Focus.Nullable -> focus.nonNullClassName
is Focus.Option -> focus.nestedClassName
is Focus.NonNull -> ""
}

if (ele.typeParameters.isEmpty()) {
"""
|${ele.visibilityModifierName} inline val <S> $Iso<S, ${ele.sourceClassName}>.${focus.paramName}: $Optional<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Lens<S, ${ele.sourceClassName}>.${focus.paramName}: $Optional<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
Expand All @@ -55,10 +75,24 @@ private fun processOptionalSyntax(ele: ADT, optic: DataClassDsl) =
|${ele.visibilityModifierName} inline val <S> $Fold<S, ${ele.sourceClassName}>.${focus.paramName}: $Fold<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Every<S, ${ele.sourceClassName}>.${focus.paramName}: $Every<S, $targetClassName> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|""".trimMargin()
} else {
"""
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Iso<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Lens<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Optional<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Prism<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Setter<S, $sourceClassNameWithParams>.${focus.paramName}(): $Setter<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Traversal<S, $sourceClassNameWithParams>.${focus.paramName}(): $Traversal<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Fold<S, $sourceClassNameWithParams>.${focus.paramName}(): $Fold<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Every<S, $sourceClassNameWithParams>.${focus.paramName}(): $Every<S, $targetClassName> = this + ${ele.sourceClassName}.${focus.paramName}()
|""".trimMargin()
}
}
}

private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String =
dsl.foci.joinToString(separator = "\n\n") { focus ->
if (ele.typeParameters.isEmpty()) {
dsl.foci.joinToString(separator = "\n\n") { focus ->
"""
|${ele.visibilityModifierName} inline val <S> $Iso<S, ${ele.sourceClassName}>.${focus.paramName}: $Prism<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Lens<S, ${ele.sourceClassName}>.${focus.paramName}: $Optional<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
Expand All @@ -69,4 +103,23 @@ private fun processPrismSyntax(ele: ADT, dsl: SealedClassDsl): String =
|${ele.visibilityModifierName} inline val <S> $Fold<S, ${ele.sourceClassName}>.${focus.paramName}: $Fold<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|${ele.visibilityModifierName} inline val <S> $Every<S, ${ele.sourceClassName}>.${focus.paramName}: $Every<S, ${focus.className}> inline get() = this + ${ele.sourceClassName}.${focus.paramName}
|""".trimMargin()
}
} else {
dsl.foci.joinToString(separator = "\n\n") { focus ->
val sourceClassNameWithParams = focus.refinedType?.qualifiedString() ?: "${ele.sourceClassName}${ele.angledTypeParameters}"
val joinedTypeParams = when {
focus.refinedArguments.isEmpty() -> ""
else -> focus.refinedArguments.joinToString(separator=",")
}
"""
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Iso<S, $sourceClassNameWithParams>.${focus.paramName}(): $Prism<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Lens<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Optional<S, $sourceClassNameWithParams>.${focus.paramName}(): $Optional<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Prism<S, $sourceClassNameWithParams>.${focus.paramName}(): $Prism<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Setter<S, $sourceClassNameWithParams>.${focus.paramName}(): $Setter<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Traversal<S, $sourceClassNameWithParams>.${focus.paramName}(): $Traversal<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Fold<S, $sourceClassNameWithParams>.${focus.paramName}(): $Fold<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|${ele.visibilityModifierName} inline fun <S,$joinedTypeParams> $Every<S, $sourceClassNameWithParams>.${focus.paramName}(): $Every<S, ${focus.className}> = this + ${ele.sourceClassName}.${focus.paramName}()
|""".trimMargin()
}
}
Expand Up @@ -63,9 +63,17 @@ private fun processElement(iso: ADT, target: Target): String {
"tuple: ${focusType()} -> ${(foci.indices).joinToString(prefix = "${iso.sourceClassName}(", postfix = ")", transform = { "tuple.${letters[it]}" })}"
}

val sourceClassNameWithParams = "${iso.sourceClassName}${iso.angledTypeParameters}"
val firstLine = when {
iso.typeParameters.isEmpty() ->
"${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()"
else ->
"${iso.visibilityModifierName} inline fun ${iso.angledTypeParameters} ${iso.sourceClassName}.Companion.iso(): $Iso<$sourceClassNameWithParams, ${focusType()}>"
}

return """
|${iso.visibilityModifierName} inline val ${iso.sourceClassName}.Companion.iso: $Iso<${iso.sourceClassName}, ${focusType()}> inline get()= $Iso(
| get = { ${iso.sourceName}: ${iso.sourceClassName} -> ${tupleConstructor()} },
|$firstLine = $Iso(
| get = { ${iso.sourceName}: $sourceClassNameWithParams -> ${tupleConstructor()} },
| reverseGet = { ${classConstructorFromTuple()} }
|)
|""".trimMargin()
Expand Down
Expand Up @@ -20,17 +20,24 @@ private fun String.toUpperCamelCase(): String =
}
)

private fun processElement(adt: ADT, foci: List<Focus>): String =
foci.joinToString(separator = "\n") { focus ->
private fun processElement(adt: ADT, foci: List<Focus>): String {
val sourceClassNameWithParams = "${adt.sourceClassName}${adt.angledTypeParameters}"
return foci.joinToString(separator = "\n") { focus ->
val firstLine = when {
adt.typeParameters.isEmpty() ->
"${adt.visibilityModifierName} inline val ${adt.sourceClassName}.Companion.${focus.lensParamName()}: $Lens<${adt.sourceClassName}, ${focus.className}> inline get()"
else ->
"${adt.visibilityModifierName} inline fun ${adt.angledTypeParameters} ${adt.sourceClassName}.Companion.${focus.lensParamName()}(): $Lens<$sourceClassNameWithParams, ${focus.className}>"
}
"""
|${adt.visibilityModifierName} inline val ${adt.sourceClassName}.Companion.${focus.lensParamName()}: $Lens<${adt.sourceClassName}, ${focus.className}> inline get()= $Lens(
| get = { ${adt.sourceName}: ${adt.sourceClassName} -> ${adt.sourceName}.${
|$firstLine = $Lens(
| get = { ${adt.sourceName}: $sourceClassNameWithParams -> ${adt.sourceName}.${
focus.paramName.plusIfNotBlank(
prefix = "`",
postfix = "`"
)
} },
| set = { ${adt.sourceName}: ${adt.sourceClassName}, value: ${focus.className} -> ${adt.sourceName}.copy(${
| set = { ${adt.sourceName}: $sourceClassNameWithParams, value: ${focus.className} -> ${adt.sourceName}.copy(${
focus.paramName.plusIfNotBlank(
prefix = "`",
postfix = "`"
Expand All @@ -39,6 +46,7 @@ private fun processElement(adt: ADT, foci: List<Focus>): String =
|)
|""".trimMargin()
}
}

fun Focus.lensParamName(): String =
when (this) {
Expand Down
Expand Up @@ -11,8 +11,23 @@ internal fun generateOptionals(ele: ADT, target: OptionalTarget) =

private fun processElement(ele: ADT, foci: List<Focus>): String =
foci.joinToString(separator = "\n") { focus ->

val targetClassName = when (focus) {
is NullableFocus -> focus.nonNullClassName
is OptionFocus -> focus.nestedClassName
is NonNullFocus -> return@joinToString ""
}

val sourceClassNameWithParams = "${ele.sourceClassName}${ele.angledTypeParameters}"
val firstLine = when {
ele.typeParameters.isEmpty() ->
"${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Optional<${ele.sourceClassName}, $targetClassName> inline get()"
else ->
"${ele.visibilityModifierName} inline fun ${ele.angledTypeParameters} ${ele.sourceClassName}.Companion.${focus.paramName}(): $Optional<$sourceClassNameWithParams, $targetClassName>"
}

fun getOrModifyF(toNullable: String = "") =
"{ ${ele.sourceName}: ${ele.sourceClassName} -> ${ele.sourceName}.${
"{ ${ele.sourceName}: $sourceClassNameWithParams -> ${ele.sourceName}.${
focus.paramName.plusIfNotBlank(
prefix = "`",
postfix = "`"
Expand All @@ -21,18 +36,17 @@ private fun processElement(ele: ADT, foci: List<Focus>): String =
fun setF(fromNullable: String = "") =
"${ele.sourceName}.copy(${focus.paramName.plusIfNotBlank(prefix = "`", postfix = "`")} = value$fromNullable)"

val (targetClassName, getOrModify, set) =
val (getOrModify, set) =
when (focus) {
is NullableFocus -> Triple(focus.nonNullClassName, getOrModifyF(), setF())
is OptionFocus ->
Triple(focus.nestedClassName, getOrModifyF(".orNull()"), setF(".toOption()"))
is NullableFocus -> Pair(getOrModifyF(), setF())
is OptionFocus -> Pair(getOrModifyF(".orNull()"), setF(".toOption()"))
is NonNullFocus -> return@joinToString ""
}

"""
|${ele.visibilityModifierName} inline val ${ele.sourceClassName}.Companion.${focus.paramName}: $Optional<${ele.sourceClassName}, $targetClassName> inline get()= $Optional(
|$firstLine = $Optional(
| getOrModify = $getOrModify,
| set = { ${ele.sourceName}: ${ele.sourceClassName}, value: $targetClassName -> $set }
| set = { ${ele.sourceName}: $sourceClassNameWithParams, value: $targetClassName -> $set }
|)
|""".trimMargin()
}

0 comments on commit 88f7b06

Please sign in to comment.