diff --git a/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt b/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt index c391cd445e..7f108a7e76 100644 --- a/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt +++ b/graphql-dgs-spring-webmvc/src/main/kotlin/com/netflix/graphql/dgs/mvc/DgsRestController.kt @@ -85,13 +85,13 @@ open class DgsRestController( produces = [MediaType.APPLICATION_JSON_VALUE] ) fun graphql( - @RequestBody body: String?, + @RequestBody body: ByteArray?, @RequestParam fileParams: Map?, @RequestParam(name = "operations") operation: String?, @RequestParam(name = "map") mapParam: String?, @RequestHeader headers: HttpHeaders, webRequest: WebRequest - ): ResponseEntity { + ): ResponseEntity { logger.debug("Validate HTTP Headers for the GraphQL endpoint...") try { @@ -116,13 +116,14 @@ open class DgsRestController( val queryVariables: Map val extensions: Map if (body != null) { - logger.debug("Reading input value: '{}'", body) + if (logger.isDebugEnabled) { + logger.debug("Reading input value: '{}'", body.decodeToString()) + } if (GraphQLMediaTypes.includesApplicationGraphQL(headers)) { - inputQuery = mapOf("query" to body) + inputQuery = mapOf("query" to body.decodeToString()) queryVariables = emptyMap() extensions = emptyMap() } else { - try { inputQuery = mapper.readValue(body) } catch (ex: Exception) { @@ -226,7 +227,7 @@ open class DgsRestController( val result = try { TimeTracer.logTime( - { mapper.writeValueAsString(executionResult.toSpecification()) }, + { mapper.writeValueAsBytes(executionResult.toSpecification()) }, logger, "Serialized JSON result in {}ms" ) @@ -234,7 +235,7 @@ open class DgsRestController( val errorMessage = "Error serializing response: ${ex.message}" val errorResponse = ExecutionResultImpl(GraphqlErrorBuilder.newError().message(errorMessage).build()) logger.error(errorMessage, ex) - mapper.writeValueAsString(errorResponse.toSpecification()) + mapper.writeValueAsBytes(errorResponse.toSpecification()) } return ResponseEntity.ok(result) diff --git a/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsMultipartPostControllerTest.kt b/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsMultipartPostControllerTest.kt index 5c8ff42070..7f241dba28 100644 --- a/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsMultipartPostControllerTest.kt +++ b/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsMultipartPostControllerTest.kt @@ -28,7 +28,6 @@ import org.assertj.core.util.Lists import org.assertj.core.util.Maps import org.intellij.lang.annotations.Language import org.junit.jupiter.api.Assertions.assertThrows -import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.springframework.http.HttpHeaders @@ -41,7 +40,10 @@ import org.springframework.web.multipart.MultipartFile @ExtendWith(MockKExtension::class) class DgsMultipartPostControllerTest { - private val httpHeaders = HttpHeaders() + private val httpHeaders = HttpHeaders().apply { + contentType = MediaType.MULTIPART_FORM_DATA + add(GraphQLCSRFRequestHeaderValidationRule.HEADER_GRAPHQL_REQUIRE_PREFLIGHT, null) + } @MockK lateinit var dgsQueryExecutor: DgsQueryExecutor @@ -49,13 +51,6 @@ class DgsMultipartPostControllerTest { @MockK lateinit var webRequest: WebRequest - @BeforeEach - fun setHeaders() { - httpHeaders.clear() - httpHeaders.contentType = MediaType.MULTIPART_FORM_DATA - httpHeaders.add(GraphQLCSRFRequestHeaderValidationRule.HEADER_GRAPHQL_REQUIRE_PREFLIGHT, null) - } - @Test fun `Multipart form request should require a preflight header`() { @Language("JSON") @@ -93,9 +88,10 @@ class DgsMultipartPostControllerTest { ) assertThat(result.statusCode).isEqualTo(HttpStatus.BAD_REQUEST) - assertThat(result.body) - .isNotEmpty - .contains("Expecting a CSRF Prevention Header but none was found, supported headers are [apollo-require-preflight, x-apollo-operation-name, graphql-require-preflight].") + assertThat(result.body).isInstanceOfSatisfying(String::class.java) { body -> + assertThat(body).isNotEmpty() + assertThat(body).contains("Expecting a CSRF Prevention Header but none was found, supported headers are [apollo-require-preflight, x-apollo-operation-name, graphql-require-preflight].") + } } @Test @@ -135,10 +131,12 @@ class DgsMultipartPostControllerTest { webRequest ) - val mapper = jacksonObjectMapper() - val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java) - assertThat(errors.size).isEqualTo(0) - assertThat(data["Response"]).isEqualTo("success") + assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body -> + val mapper = jacksonObjectMapper() + val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java) + assertThat(errors.size).isEqualTo(0) + assertThat(data["Response"]).isEqualTo("success") + } } @Test @@ -184,10 +182,12 @@ class DgsMultipartPostControllerTest { webRequest ) - val mapper = jacksonObjectMapper() - val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java) - assertThat(errors.size).isEqualTo(0) - assertThat(data["Response"]).isEqualTo("success") + assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body -> + val mapper = jacksonObjectMapper() + val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java) + assertThat(errors.size).isEqualTo(0) + assertThat(data["Response"]).isEqualTo("success") + } } @Test @@ -229,10 +229,12 @@ class DgsMultipartPostControllerTest { webRequest ) - val mapper = jacksonObjectMapper() - val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java) - assertThat(errors.size).isEqualTo(0) - assertThat(data["Response"]).isEqualTo("success") + assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body -> + val mapper = jacksonObjectMapper() + val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java) + assertThat(errors.size).isEqualTo(0) + assertThat(data["Response"]).isEqualTo("success") + } } @Test diff --git a/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt b/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt index e53edad671..f591e16e23 100644 --- a/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt +++ b/graphql-dgs-spring-webmvc/src/test/kotlin/com/netflix/graphql/dgs/mvc/DgsRestControllerTest.kt @@ -26,7 +26,6 @@ import io.mockk.impl.annotations.MockK import io.mockk.junit5.MockKExtension import io.mockk.slot import org.assertj.core.api.Assertions.assertThat -import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.springframework.http.HttpHeaders @@ -38,7 +37,9 @@ import org.springframework.web.context.request.WebRequest @ExtendWith(MockKExtension::class) class DgsRestControllerTest { - private val httpHeaders = HttpHeaders() + private val httpHeaders = HttpHeaders().apply { + contentType = MediaType.APPLICATION_JSON + } @MockK lateinit var dgsQueryExecutor: DgsQueryExecutor @@ -46,12 +47,6 @@ class DgsRestControllerTest { @MockK lateinit var webRequest: WebRequest - @BeforeEach - fun setHeaders() { - httpHeaders.clear() - httpHeaders.contentType = MediaType.APPLICATION_JSON - } - @Test fun `Is able to execute a a well formed query`() { val queryString = "query { hello }" @@ -72,11 +67,14 @@ class DgsRestControllerTest { ) } returns ExecutionResultImpl.newExecutionResult().data(mapOf(Pair("hello", "hello"))).build() - val result = DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, httpHeaders, webRequest) - val mapper = jacksonObjectMapper() - val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java) - assertThat(errors.size).isEqualTo(0) - assertThat(data["hello"]).isEqualTo("hello") + val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest) + + assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body -> + val mapper = jacksonObjectMapper() + val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java) + assertThat(errors.size).isEqualTo(0) + assertThat(data["hello"]).isEqualTo("hello") + } } @Test @@ -101,14 +99,15 @@ class DgsRestControllerTest { ) } returns ExecutionResultImpl.newExecutionResult().data(mapOf(Pair("hi", "there"))).build() - val result = - DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, httpHeaders, webRequest) - val mapper = jacksonObjectMapper() - val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java) - assertThat(errors.size).isEqualTo(0) - assertThat(data["hi"]).isEqualTo("there") + val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest) - assertThat(capturedOperationName.captured).isEqualTo("operationB") + assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body -> + val mapper = jacksonObjectMapper() + val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java) + assertThat(errors.size).isEqualTo(0) + assertThat(data["hi"]).isEqualTo("there") + assertThat(capturedOperationName.captured).isEqualTo("operationB") + } } @Test @@ -132,11 +131,14 @@ class DgsRestControllerTest { val headers = HttpHeaders() headers.contentType = MediaType("application", "graphql") - val result = DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, headers, webRequest) - val mapper = jacksonObjectMapper() - val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java) - assertThat(errors.size).isEqualTo(0) - assertThat(data["hello"]).isEqualTo("hello") + val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, headers, webRequest) + + assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body -> + val mapper = jacksonObjectMapper() + val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java) + assertThat(errors.size).isEqualTo(0) + assertThat(data["hello"]).isEqualTo("hello") + } } @Test @@ -161,7 +163,7 @@ class DgsRestControllerTest { .data(SubscriptionPublisher(null, null)).build() val result = - DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, httpHeaders, webRequest) + DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest) assertThat(result.statusCode).isEqualTo(HttpStatus.BAD_REQUEST) assertThat(result.body).isEqualTo("Trying to execute subscription on /graphql. Use /subscriptions instead!") } @@ -171,7 +173,7 @@ class DgsRestControllerTest { val requestBody = "" val result = DgsRestController(dgsQueryExecutor) - .graphql(requestBody, null, null, null, httpHeaders, webRequest) + .graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest) assertThat(result) .isInstanceOf(ResponseEntity::class.java) diff --git a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DefaultDgsQueryExecutor.kt b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DefaultDgsQueryExecutor.kt index 17e242a47f..d783d188ca 100644 --- a/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DefaultDgsQueryExecutor.kt +++ b/graphql-dgs/src/main/kotlin/com/netflix/graphql/dgs/internal/DefaultDgsQueryExecutor.kt @@ -113,7 +113,7 @@ class DefaultDgsQueryExecutor( override fun executeAndExtractJsonPath(query: String, jsonPath: String, servletWebRequest: ServletWebRequest): T { val httpHeaders = HttpHeaders() servletWebRequest.headerNames.forEach { name -> - httpHeaders[name] = servletWebRequest.getHeaderValues(name).asList() + httpHeaders.addAll(name, servletWebRequest.getHeaderValues(name).orEmpty().toList()) } return JsonPath.read(getJsonResult(query, emptyMap(), httpHeaders, servletWebRequest), jsonPath) }