Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework the API expressed by DgsExecutionResult #1298

Merged
merged 1 commit into from Oct 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -16,25 +16,31 @@

package com.netflix.graphql.dgs.example.datafetcher;

import com.netflix.graphql.dgs.internal.DgsExecutionResult;
import com.netflix.graphql.dgs.DgsExecutionResult;
import graphql.ExecutionResult;
import graphql.execution.instrumentation.InstrumentationState;
import graphql.execution.instrumentation.SimpleInstrumentation;
import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.HttpHeaders;
import org.springframework.stereotype.Component;

import java.util.concurrent.CompletableFuture;

@Component
public class MyInstrumentation extends SimpleInstrumentation {
@NotNull
@Override
public CompletableFuture<ExecutionResult> instrumentExecutionResult(ExecutionResult executionResult, InstrumentationExecutionParameters parameters) {
public CompletableFuture<ExecutionResult> instrumentExecutionResult(ExecutionResult executionResult,
InstrumentationExecutionParameters parameters,
InstrumentationState state) {
HttpHeaders responseHeaders = new HttpHeaders();
responseHeaders.add("myHeader", "hello");

return super.instrumentExecutionResult(new DgsExecutionResult(
executionResult,
responseHeaders
), parameters);
return super.instrumentExecutionResult(
DgsExecutionResult.builder().executionResult(executionResult).headers(responseHeaders).build(),
parameters,
state
);
}
}
Expand Up @@ -20,7 +20,7 @@ import com.fasterxml.jackson.core.JsonParseException
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.exc.MismatchedInputException
import com.fasterxml.jackson.module.kotlin.readValue
import com.netflix.graphql.dgs.internal.DgsExecutionResult
import com.netflix.graphql.dgs.DgsExecutionResult
import com.netflix.graphql.dgs.reactive.DgsReactiveQueryExecutor
import graphql.ExecutionResult
import org.slf4j.Logger
Expand Down Expand Up @@ -77,12 +77,12 @@ class DefaultDgsWebfluxHttpHandler(
return executionResult.flatMap { result ->
val dgsExecutionResult = when (result) {
is DgsExecutionResult -> result
else -> DgsExecutionResult(result)
else -> DgsExecutionResult.builder().executionResult(result).build()
}

ServerResponse
.status(dgsExecutionResult.status)
.headers { it.addAll(dgsExecutionResult.headers) }
.headers { it.addAll(dgsExecutionResult.headers()) }
.bodyValue(dgsExecutionResult.toSpecification())
}.onErrorResume { ex ->
when (ex) {
Expand Down
Expand Up @@ -21,8 +21,8 @@ import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.exc.MismatchedInputException
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import com.netflix.graphql.dgs.DgsExecutionResult
import com.netflix.graphql.dgs.DgsQueryExecutor
import com.netflix.graphql.dgs.internal.DgsExecutionResult
import com.netflix.graphql.dgs.internal.utils.MultipartVariableMapper
import com.netflix.graphql.dgs.internal.utils.TimeTracer
import graphql.execution.reactive.SubscriptionPublisher
Expand Down Expand Up @@ -228,7 +228,7 @@ open class DgsRestController(

return when (executionResult) {
is DgsExecutionResult -> executionResult.toSpringResponse()
else -> DgsExecutionResult(executionResult).toSpringResponse()
else -> DgsExecutionResult.builder().executionResult(executionResult).build().toSpringResponse()
}
}
}
Expand Up @@ -199,7 +199,10 @@ class DgsRestControllerTest {
any(),
any()
)
} returns ExecutionResultImpl.newExecutionResult().data(mapOf(Pair("hello", "hello"))).extensions(mutableMapOf(Pair(DgsRestController.DGS_RESPONSE_HEADERS_KEY, mapOf(Pair("myHeader", "hello")))) as Map<Any, Any>?).build()
} returns ExecutionResultImpl.newExecutionResult()
.data(mapOf("hello" to "hello"))
.extensions(mutableMapOf(DgsRestController.DGS_RESPONSE_HEADERS_KEY to mapOf("myHeader" to "hello")) as Map<Any, Any>?)
.build()

val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest)
assertThat(result.headers["myHeader"]).contains("hello")
Expand Down
Expand Up @@ -14,84 +14,34 @@
* limitations under the License.
*/

package com.netflix.graphql.dgs.internal
package com.netflix.graphql.dgs

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.exc.InvalidDefinitionException
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.netflix.graphql.dgs.internal.utils.TimeTracer
import graphql.ExecutionResult
import graphql.ExecutionResultImpl
import graphql.GraphQLError
import graphql.GraphqlErrorBuilder
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.springframework.http.HttpHeaders
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseEntity

class DgsExecutionResult @JvmOverloads constructor(
val executionResult: ExecutionResult,
val headers: HttpHeaders = HttpHeaders(),
val status: HttpStatus = HttpStatus.OK
class DgsExecutionResult constructor(
private val executionResult: ExecutionResult,
private var headers: HttpHeaders,
val status: HttpStatus
) : ExecutionResult by executionResult {
companion object {
// defined in here and DgsRestController, for backwards compatibility.
// keep these two variables synced.
const val DGS_RESPONSE_HEADERS_KEY = "dgs-response-headers"
private val sentinelObject = Any()
private val logger: Logger = LoggerFactory.getLogger(DgsExecutionResult::class.java)
}

init {
addExtensionsHeaderKeyToHeader()
}

constructor(
status: HttpStatus = HttpStatus.OK,
headers: HttpHeaders = HttpHeaders.EMPTY,
errors: List<GraphQLError> = listOf(),
extensions: Map<Any, Any>? = null,

// By default, assign data as a sentinel object.
// If we were to default to null here, this constructor
// would be unable to discriminate between an intentionally null
// response and one that the user left default.
data: Any? = sentinelObject
) : this(
headers = headers,
status = status,
executionResult = ExecutionResultImpl
.newExecutionResult()
.errors(errors)
.extensions(extensions)
.apply {
if (data != sentinelObject) {
data(data)
}
}
.build()
)

// for backwards compat with https://github.com/Netflix/dgs-framework/pull/1261.
private fun addExtensionsHeaderKeyToHeader() {
if (executionResult.extensions?.containsKey(DGS_RESPONSE_HEADERS_KEY) == true) {
val dgsResponseHeaders = executionResult.extensions[DGS_RESPONSE_HEADERS_KEY]

if (dgsResponseHeaders is Map<*, *>) {
dgsResponseHeaders.forEach {
if (it.key != null) {
headers.add(it.key.toString(), it.value?.toString())
}
}
} else {
logger.warn(
"{} must be of type java.util.Map, but was {}",
DGS_RESPONSE_HEADERS_KEY,
dgsResponseHeaders?.javaClass?.name
)
}
}
/** Read-Only reference to the HTTP Headers. */
fun headers(): HttpHeaders {
return HttpHeaders.readOnlyHttpHeaders(headers)
}

fun toSpringResponse(
Expand All @@ -117,7 +67,7 @@ class DgsExecutionResult @JvmOverloads constructor(
)
}

// overridden for compatibility with https://github.com/Netflix/dgs-framework/pull/1261.
// Refer to https://github.com/Netflix/dgs-framework/pull/1261 for further details.
override fun toSpecification(): MutableMap<String, Any> {
val spec = executionResult.toSpecification()

Expand All @@ -133,4 +83,73 @@ class DgsExecutionResult @JvmOverloads constructor(

return spec
}

// Refer to https://github.com/Netflix/dgs-framework/pull/1261 for further details.
private fun addExtensionsHeaderKeyToHeader() {
if (executionResult.extensions?.containsKey(DGS_RESPONSE_HEADERS_KEY) == true) {
val dgsResponseHeaders = executionResult.extensions[DGS_RESPONSE_HEADERS_KEY]
if (dgsResponseHeaders is Map<*, *> && dgsResponseHeaders.isNotEmpty()) {
// If the HttpHeaders are empty/read-only we need to switch to a new instance that allows us
// to store the headers that are part of the GraphQL response _extensions_.
if (headers == HttpHeaders.EMPTY) {
headers = HttpHeaders()
}

dgsResponseHeaders.forEach {
if (it.key != null) {
headers.add(it.key.toString(), it.value?.toString())
}
}
} else {
logger.warn(
"{} must be of type java.util.Map, but was {}",
DGS_RESPONSE_HEADERS_KEY,
dgsResponseHeaders?.javaClass?.name
)
}
}
}

/**
* Facilitate the construction of a [DgsExecutionResult] instance.
*/
class Builder {
var executionResult: ExecutionResult = DEFAULT_EXECUTION_RESULT
private set

fun executionResult(executionResult: ExecutionResult) =
apply { this.executionResult = executionResult }

fun executionResult(executionResultBuilder: ExecutionResultImpl.Builder) =
apply { this.executionResult = executionResultBuilder.build() }

var headers: HttpHeaders = HttpHeaders.EMPTY
private set

fun headers(headers: HttpHeaders) = apply { this.headers = headers }

var status: HttpStatus = HttpStatus.OK
private set

fun status(status: HttpStatus) = apply { this.status = status }

fun build() = DgsExecutionResult(
executionResult = checkNotNull(executionResult),
headers = headers,
status = status
)

companion object {
private val DEFAULT_EXECUTION_RESULT = ExecutionResultImpl.newExecutionResult().build()
}
}

companion object {
// defined in here and DgsRestController, for backwards compatibility. Keep these two variables synced.
const val DGS_RESPONSE_HEADERS_KEY = "dgs-response-headers"
private val logger: Logger = LoggerFactory.getLogger(DgsExecutionResult::class.java)

@JvmStatic
fun builder(): Builder = Builder()
}
}
Expand Up @@ -26,6 +26,7 @@ import com.jayway.jsonpath.Option
import com.jayway.jsonpath.ParseContext
import com.jayway.jsonpath.spi.json.JacksonJsonProvider
import com.jayway.jsonpath.spi.mapper.JacksonMappingProvider
import com.netflix.graphql.dgs.DgsExecutionResult
import com.netflix.graphql.dgs.context.DgsContext
import com.netflix.graphql.dgs.exceptions.DgsBadRequestException
import graphql.ExecutionInput
Expand Down Expand Up @@ -80,14 +81,20 @@ object BaseDgsQueryExecutor {

if (!StringUtils.hasText(query)) {
return CompletableFuture.completedFuture(
DgsExecutionResult(
status = HttpStatus.BAD_REQUEST,
errors = listOf(
DgsBadRequestException
.NULL_OR_EMPTY_QUERY_EXCEPTION
.toGraphQlError()
)
)
DgsExecutionResult
.builder()
.status(HttpStatus.BAD_REQUEST)
.executionResult(
ExecutionResultImpl
.newExecutionResult()
.errors(
listOf(
DgsBadRequestException
.NULL_OR_EMPTY_QUERY_EXCEPTION
.toGraphQlError()
)
)
).build()
)
}

Expand Down
Expand Up @@ -16,6 +16,8 @@

package com.netflix.graphql.dgs.internal

import com.netflix.graphql.dgs.DgsExecutionResult
import graphql.ExecutionResultImpl
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.entry
import org.junit.jupiter.api.Nested
Expand All @@ -27,17 +29,18 @@ class DgsExecutionResultTest {
@Test
fun `should be able to pass null for data`() {
assertThat(
DgsExecutionResult(
data = null
).toSpecification()
DgsExecutionResult
.builder()
.executionResult(ExecutionResultImpl.newExecutionResult().data(null))
.build()
.toSpecification()
).contains(entry("data", null))
}

@Test
fun `should default to not having data`() {
assertThat(
DgsExecutionResult()
.toSpecification()
DgsExecutionResult.builder().build().toSpecification()
).doesNotContainKey("data")
}

Expand All @@ -46,7 +49,10 @@ class DgsExecutionResultTest {
val data = "Check under your chair"

assertThat(
DgsExecutionResult(data = data)
DgsExecutionResult
.builder()
.executionResult(ExecutionResultImpl.newExecutionResult().data(data))
.build()
.toSpecification()
).contains(entry("data", data))
}
Expand All @@ -59,11 +65,8 @@ class DgsExecutionResultTest {
headers.add("We can add headers now??", "Yes we can")

assertThat(
DgsExecutionResult(
headers = headers
).toSpringResponse()
.headers
.toMap()
DgsExecutionResult.builder().headers(headers).build().toSpringResponse()
.headers.toMap()
).containsAllEntriesOf(headers.toMap())
}

Expand All @@ -72,11 +75,8 @@ class DgsExecutionResultTest {
val httpStatusCode = HttpStatus.ALREADY_REPORTED

assertThat(
DgsExecutionResult(
status = httpStatusCode
).toSpringResponse()
.statusCode
.value()
DgsExecutionResult.builder().status(httpStatusCode).build().toSpringResponse()
.statusCode.value()
).isEqualTo(httpStatusCode.value())
}
}
Expand Down