Skip to content

Commit

Permalink
Merge branch 'master' into srossillo.fix-graphiql-issues
Browse files Browse the repository at this point in the history
  • Loading branch information
foo4u committed Aug 15, 2022
2 parents aa88ff6 + e74038f commit 7407753
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 59 deletions.
Expand Up @@ -85,13 +85,13 @@ open class DgsRestController(
produces = [MediaType.APPLICATION_JSON_VALUE]
)
fun graphql(
@RequestBody body: String?,
@RequestBody body: ByteArray?,
@RequestParam fileParams: Map<String, MultipartFile>?,
@RequestParam(name = "operations") operation: String?,
@RequestParam(name = "map") mapParam: String?,
@RequestHeader headers: HttpHeaders,
webRequest: WebRequest
): ResponseEntity<String> {
): ResponseEntity<Any> {

logger.debug("Validate HTTP Headers for the GraphQL endpoint...")
try {
Expand All @@ -116,13 +116,14 @@ open class DgsRestController(
val queryVariables: Map<String, Any>
val extensions: Map<String, Any>
if (body != null) {
logger.debug("Reading input value: '{}'", body)
if (logger.isDebugEnabled) {
logger.debug("Reading input value: '{}'", body.decodeToString())
}
if (GraphQLMediaTypes.includesApplicationGraphQL(headers)) {
inputQuery = mapOf("query" to body)
inputQuery = mapOf("query" to body.decodeToString())
queryVariables = emptyMap()
extensions = emptyMap()
} else {

try {
inputQuery = mapper.readValue(body)
} catch (ex: Exception) {
Expand Down Expand Up @@ -226,15 +227,15 @@ open class DgsRestController(

val result = try {
TimeTracer.logTime(
{ mapper.writeValueAsString(executionResult.toSpecification()) },
{ mapper.writeValueAsBytes(executionResult.toSpecification()) },
logger,
"Serialized JSON result in {}ms"
)
} catch (ex: InvalidDefinitionException) {
val errorMessage = "Error serializing response: ${ex.message}"
val errorResponse = ExecutionResultImpl(GraphqlErrorBuilder.newError().message(errorMessage).build())
logger.error(errorMessage, ex)
mapper.writeValueAsString(errorResponse.toSpecification())
mapper.writeValueAsBytes(errorResponse.toSpecification())
}

return ResponseEntity.ok(result)
Expand Down
Expand Up @@ -28,7 +28,6 @@ import org.assertj.core.util.Lists
import org.assertj.core.util.Maps
import org.intellij.lang.annotations.Language
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.http.HttpHeaders
Expand All @@ -41,21 +40,17 @@ import org.springframework.web.multipart.MultipartFile
@ExtendWith(MockKExtension::class)
class DgsMultipartPostControllerTest {

private val httpHeaders = HttpHeaders()
private val httpHeaders = HttpHeaders().apply {
contentType = MediaType.MULTIPART_FORM_DATA
add(GraphQLCSRFRequestHeaderValidationRule.HEADER_GRAPHQL_REQUIRE_PREFLIGHT, null)
}

@MockK
lateinit var dgsQueryExecutor: DgsQueryExecutor

@MockK
lateinit var webRequest: WebRequest

@BeforeEach
fun setHeaders() {
httpHeaders.clear()
httpHeaders.contentType = MediaType.MULTIPART_FORM_DATA
httpHeaders.add(GraphQLCSRFRequestHeaderValidationRule.HEADER_GRAPHQL_REQUIRE_PREFLIGHT, null)
}

@Test
fun `Multipart form request should require a preflight header`() {
@Language("JSON")
Expand Down Expand Up @@ -93,9 +88,10 @@ class DgsMultipartPostControllerTest {
)

assertThat(result.statusCode).isEqualTo(HttpStatus.BAD_REQUEST)
assertThat(result.body)
.isNotEmpty
.contains("Expecting a CSRF Prevention Header but none was found, supported headers are [apollo-require-preflight, x-apollo-operation-name, graphql-require-preflight].")
assertThat(result.body).isInstanceOfSatisfying(String::class.java) { body ->
assertThat(body).isNotEmpty()
assertThat(body).contains("Expecting a CSRF Prevention Header but none was found, supported headers are [apollo-require-preflight, x-apollo-operation-name, graphql-require-preflight].")
}
}

@Test
Expand Down Expand Up @@ -135,10 +131,12 @@ class DgsMultipartPostControllerTest {
webRequest
)

val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["Response"]).isEqualTo("success")
assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body ->
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["Response"]).isEqualTo("success")
}
}

@Test
Expand Down Expand Up @@ -184,10 +182,12 @@ class DgsMultipartPostControllerTest {
webRequest
)

val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["Response"]).isEqualTo("success")
assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body ->
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["Response"]).isEqualTo("success")
}
}

@Test
Expand Down Expand Up @@ -229,10 +229,12 @@ class DgsMultipartPostControllerTest {
webRequest
)

val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["Response"]).isEqualTo("success")
assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body ->
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["Response"]).isEqualTo("success")
}
}

@Test
Expand Down
Expand Up @@ -26,7 +26,6 @@ import io.mockk.impl.annotations.MockK
import io.mockk.junit5.MockKExtension
import io.mockk.slot
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.http.HttpHeaders
Expand All @@ -38,20 +37,16 @@ import org.springframework.web.context.request.WebRequest
@ExtendWith(MockKExtension::class)
class DgsRestControllerTest {

private val httpHeaders = HttpHeaders()
private val httpHeaders = HttpHeaders().apply {
contentType = MediaType.APPLICATION_JSON
}

@MockK
lateinit var dgsQueryExecutor: DgsQueryExecutor

@MockK
lateinit var webRequest: WebRequest

@BeforeEach
fun setHeaders() {
httpHeaders.clear()
httpHeaders.contentType = MediaType.APPLICATION_JSON
}

@Test
fun `Is able to execute a a well formed query`() {
val queryString = "query { hello }"
Expand All @@ -72,11 +67,14 @@ class DgsRestControllerTest {
)
} returns ExecutionResultImpl.newExecutionResult().data(mapOf(Pair("hello", "hello"))).build()

val result = DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, httpHeaders, webRequest)
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["hello"]).isEqualTo("hello")
val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest)

assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body ->
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["hello"]).isEqualTo("hello")
}
}

@Test
Expand All @@ -101,14 +99,15 @@ class DgsRestControllerTest {
)
} returns ExecutionResultImpl.newExecutionResult().data(mapOf(Pair("hi", "there"))).build()

val result =
DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, httpHeaders, webRequest)
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["hi"]).isEqualTo("there")
val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest)

assertThat(capturedOperationName.captured).isEqualTo("operationB")
assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body ->
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["hi"]).isEqualTo("there")
assertThat(capturedOperationName.captured).isEqualTo("operationB")
}
}

@Test
Expand All @@ -132,11 +131,14 @@ class DgsRestControllerTest {

val headers = HttpHeaders()
headers.contentType = MediaType("application", "graphql")
val result = DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, headers, webRequest)
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(result.body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["hello"]).isEqualTo("hello")
val result = DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, headers, webRequest)

assertThat(result.body).isInstanceOfSatisfying(ByteArray::class.java) { body ->
val mapper = jacksonObjectMapper()
val (data, errors) = mapper.readValue(body, GraphQLResponse::class.java)
assertThat(errors.size).isEqualTo(0)
assertThat(data["hello"]).isEqualTo("hello")
}
}

@Test
Expand All @@ -161,7 +163,7 @@ class DgsRestControllerTest {
.data(SubscriptionPublisher(null, null)).build()

val result =
DgsRestController(dgsQueryExecutor).graphql(requestBody, null, null, null, httpHeaders, webRequest)
DgsRestController(dgsQueryExecutor).graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest)
assertThat(result.statusCode).isEqualTo(HttpStatus.BAD_REQUEST)
assertThat(result.body).isEqualTo("Trying to execute subscription on /graphql. Use /subscriptions instead!")
}
Expand All @@ -171,7 +173,7 @@ class DgsRestControllerTest {
val requestBody = ""
val result =
DgsRestController(dgsQueryExecutor)
.graphql(requestBody, null, null, null, httpHeaders, webRequest)
.graphql(requestBody.toByteArray(), null, null, null, httpHeaders, webRequest)

assertThat(result)
.isInstanceOf(ResponseEntity::class.java)
Expand Down
Expand Up @@ -113,7 +113,7 @@ class DefaultDgsQueryExecutor(
override fun <T : Any?> executeAndExtractJsonPath(query: String, jsonPath: String, servletWebRequest: ServletWebRequest): T {
val httpHeaders = HttpHeaders()
servletWebRequest.headerNames.forEach { name ->
httpHeaders[name] = servletWebRequest.getHeaderValues(name).asList()
httpHeaders.addAll(name, servletWebRequest.getHeaderValues(name).orEmpty().toList())
}
return JsonPath.read(getJsonResult(query, emptyMap(), httpHeaders, servletWebRequest), jsonPath)
}
Expand Down

0 comments on commit 7407753

Please sign in to comment.