diff --git a/graphql-dgs-client/dependencies.lock b/graphql-dgs-client/dependencies.lock index 7f52cd843..77299a450 100644 --- a/graphql-dgs-client/dependencies.lock +++ b/graphql-dgs-client/dependencies.lock @@ -663,6 +663,9 @@ "locked": "1.12.3" }, "io.projectreactor:reactor-core": { + "firstLevelTransitive": [ + "com.netflix.graphql.dgs:graphql-dgs-subscriptions-sse" + ], "locked": "3.4.10" }, "io.projectreactor:reactor-test": { diff --git a/graphql-dgs-client/src/test/kotlin/com/netflix/graphql/dgs/client/SSESubscriptionGraphQLClientTest.kt b/graphql-dgs-client/src/test/kotlin/com/netflix/graphql/dgs/client/SSESubscriptionGraphQLClientTest.kt index 976f6c74f..0a472a4f8 100644 --- a/graphql-dgs-client/src/test/kotlin/com/netflix/graphql/dgs/client/SSESubscriptionGraphQLClientTest.kt +++ b/graphql-dgs-client/src/test/kotlin/com/netflix/graphql/dgs/client/SSESubscriptionGraphQLClientTest.kt @@ -27,7 +27,6 @@ import graphql.language.TypeName import graphql.schema.idl.TypeDefinitionRegistry import org.junit.jupiter.api.Assertions.assertThrows import org.junit.jupiter.api.Test -import org.slf4j.LoggerFactory import org.springframework.boot.autoconfigure.SpringBootApplication import org.springframework.boot.test.context.SpringBootTest import org.springframework.boot.web.server.LocalServerPort @@ -42,11 +41,8 @@ import reactor.test.StepVerifier ) internal class SSESubscriptionGraphQLClientTest { - val logger = LoggerFactory.getLogger(SSESubscriptionGraphQLClient::class.java) - - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") @LocalServerPort - lateinit var port: Integer + var port: Int? = null @Test fun `A successful subscription should publish ticks`() { diff --git a/graphql-dgs-subscriptions-sse-autoconfigure/dependencies.lock b/graphql-dgs-subscriptions-sse-autoconfigure/dependencies.lock index 0f15dbfc8..a76d90b26 100644 --- a/graphql-dgs-subscriptions-sse-autoconfigure/dependencies.lock +++ b/graphql-dgs-subscriptions-sse-autoconfigure/dependencies.lock @@ -372,6 +372,12 @@ ], "project": true }, + "io.projectreactor:reactor-core": { + "firstLevelTransitive": [ + "com.netflix.graphql.dgs:graphql-dgs-subscriptions-sse" + ], + "locked": "3.4.10" + }, "net.datafaker:datafaker": { "firstLevelTransitive": [ "com.netflix.graphql.dgs:graphql-dgs-mocking" @@ -639,6 +645,12 @@ "io.mockk:mockk": { "locked": "1.12.3" }, + "io.projectreactor:reactor-core": { + "firstLevelTransitive": [ + "com.netflix.graphql.dgs:graphql-dgs-subscriptions-sse" + ], + "locked": "3.4.10" + }, "net.datafaker:datafaker": { "firstLevelTransitive": [ "com.netflix.graphql.dgs:graphql-dgs-mocking" diff --git a/graphql-dgs-subscriptions-sse/build.gradle.kts b/graphql-dgs-subscriptions-sse/build.gradle.kts index 3ff55c50a..1ed07a81f 100644 --- a/graphql-dgs-subscriptions-sse/build.gradle.kts +++ b/graphql-dgs-subscriptions-sse/build.gradle.kts @@ -20,6 +20,9 @@ dependencies { implementation("com.fasterxml.jackson.module:jackson-module-kotlin") implementation("org.springframework:spring-web") implementation("org.springframework:spring-webmvc") + implementation("io.projectreactor:reactor-core") testImplementation("io.projectreactor:reactor-test") + testImplementation("org.springframework.boot:spring-boot-starter-test") + testImplementation("org.springframework.boot:spring-boot-starter-tomcat") } diff --git a/graphql-dgs-subscriptions-sse/dependencies.lock b/graphql-dgs-subscriptions-sse/dependencies.lock index 255a22bc7..6c65e6360 100644 --- a/graphql-dgs-subscriptions-sse/dependencies.lock +++ b/graphql-dgs-subscriptions-sse/dependencies.lock @@ -80,6 +80,9 @@ ], "project": true }, + "io.projectreactor:reactor-core": { + "locked": "3.4.10" + }, "org.jetbrains.kotlin:kotlin-bom": { "locked": "1.5.32" }, @@ -368,6 +371,9 @@ ], "project": true }, + "io.projectreactor:reactor-core": { + "locked": "3.4.10" + }, "net.datafaker:datafaker": { "firstLevelTransitive": [ "com.netflix.graphql.dgs:graphql-dgs-mocking" @@ -516,6 +522,9 @@ "io.mockk:mockk": { "locked": "1.12.3" }, + "io.projectreactor:reactor-core": { + "locked": "3.4.10" + }, "io.projectreactor:reactor-test": { "locked": "3.4.10" }, @@ -537,6 +546,9 @@ "org.springframework.boot:spring-boot-starter-test": { "locked": "2.3.12.RELEASE" }, + "org.springframework.boot:spring-boot-starter-tomcat": { + "locked": "2.3.12.RELEASE" + }, "org.springframework.cloud:spring-cloud-dependencies": { "locked": "Hoxton.SR12" }, @@ -626,6 +638,9 @@ "io.mockk:mockk": { "locked": "1.12.3" }, + "io.projectreactor:reactor-core": { + "locked": "3.4.10" + }, "io.projectreactor:reactor-test": { "locked": "3.4.10" }, @@ -677,6 +692,9 @@ "org.springframework.boot:spring-boot-starter-test": { "locked": "2.3.12.RELEASE" }, + "org.springframework.boot:spring-boot-starter-tomcat": { + "locked": "2.3.12.RELEASE" + }, "org.springframework.cloud:spring-cloud-dependencies": { "locked": "Hoxton.SR12" }, diff --git a/graphql-dgs-subscriptions-sse/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandler.kt b/graphql-dgs-subscriptions-sse/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandler.kt index 610a5a42c..dbed8eea1 100644 --- a/graphql-dgs-subscriptions-sse/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandler.kt +++ b/graphql-dgs-subscriptions-sse/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandler.kt @@ -22,20 +22,25 @@ import com.netflix.graphql.types.subscription.QueryPayload import com.netflix.graphql.types.subscription.SSEDataPayload import graphql.ExecutionResult import graphql.InvalidSyntaxError +import graphql.language.OperationDefinition +import graphql.parser.InvalidSyntaxException +import graphql.parser.Parser import graphql.validation.ValidationError import org.reactivestreams.Publisher -import org.reactivestreams.Subscriber -import org.reactivestreams.Subscription import org.slf4j.Logger import org.slf4j.LoggerFactory import org.springframework.http.MediaType -import org.springframework.http.ResponseEntity +import org.springframework.http.codec.ServerSentEvent import org.springframework.web.bind.annotation.RequestMapping import org.springframework.web.bind.annotation.RequestParam import org.springframework.web.bind.annotation.RestController -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter +import org.springframework.web.server.ServerErrorException +import org.springframework.web.server.ServerWebInputException +import reactor.core.publisher.Flux import java.nio.charset.StandardCharsets -import java.util.* +import java.util.Base64 +import java.util.UUID +import com.netflix.graphql.types.subscription.Error as SseError /** * This class is defined as "open" only for proxy/aop use cases. It is not considered part of the API, and backwards compatibility is not guaranteed. @@ -44,119 +49,73 @@ import java.util.* @RestController open class DgsSSESubscriptionHandler(open val dgsQueryExecutor: DgsQueryExecutor) { - @RequestMapping("/subscriptions", produces = ["text/event-stream"]) - fun subscriptionWithId(@RequestParam("query") queryBase64: String): ResponseEntity { - val emitter = SseEmitter(-1) - val sessionId = UUID.randomUUID().toString() + @RequestMapping("/subscriptions", produces = [MediaType.TEXT_EVENT_STREAM_VALUE]) + fun subscriptionWithId(@RequestParam("query") queryBase64: String): Flux> { val query = try { String(Base64.getDecoder().decode(queryBase64), StandardCharsets.UTF_8) } catch (ex: IllegalArgumentException) { - emitter.send("Error decoding base64 encoded query") - emitter.complete() - return ResponseEntity.badRequest().body(emitter) + throw ServerWebInputException("Error decoding base64-encoded query") } val queryPayload = try { mapper.readValue(query, QueryPayload::class.java) } catch (ex: Exception) { - emitter.send("Error parsing query: ${ex.message}") - emitter.complete() - return ResponseEntity.badRequest().body(emitter) + throw ServerWebInputException("Error parsing query: ${ex.message}") + } + + if (!isSubscriptionQuery(queryPayload.query)) { + throw ServerWebInputException("Invalid query. operation type is not a subscription") } val executionResult: ExecutionResult = dgsQueryExecutor.execute(queryPayload.query, queryPayload.variables) if (executionResult.errors.isNotEmpty()) { - return if ( - executionResult.errors.asSequence().filterIsInstance().any() || - executionResult.errors.asSequence().filterIsInstance().any() - ) { - val errorMessage = "Subscription query failed to validate: ${executionResult.errors.joinToString(", ")}" - emitter.send(errorMessage) - emitter.complete() - ResponseEntity.badRequest().body(emitter) + val errorMessage = if (executionResult.errors.any { error -> error is ValidationError || error is InvalidSyntaxError }) { + "Subscription query failed to validate: ${executionResult.errors.joinToString()}" } else { - val errorMessage = "Error executing subscription query: ${executionResult.errors.joinToString(", ")}" - logger.error(errorMessage) - emitter.send(errorMessage) - emitter.complete() - ResponseEntity.status(500).body(emitter) + "Error executing subscription query: ${executionResult.errors.joinToString()}" } - } - - val subscriber = object : Subscriber { - lateinit var subscription: Subscription - - override fun onSubscribe(s: Subscription) { - logger.info("Started subscription with id {} for request {}", sessionId, queryPayload) - subscription = s - s.request(1) - } - - override fun onNext(t: ExecutionResult) { - val event = SseEmitter.event() - .data( - mapper.writeValueAsString( - SSEDataPayload(data = t.getData(), errors = t.errors, subId = sessionId) - ), - MediaType.APPLICATION_JSON - ).id(UUID.randomUUID().toString()) - emitter.send(event) - - subscription.request(1) - } - - override fun onError(t: Throwable) { - logger.error("Error on subscription {}", sessionId, t) - val event = SseEmitter.event() - .data( - mapper.writeValueAsString( - SSEDataPayload( - data = null, - errors = listOf(Error(t.message)), - subId = sessionId - ) - ), - MediaType.APPLICATION_JSON - ) - - emitter.send(event) - emitter.completeWithError(t) - } - - override fun onComplete() { - emitter.complete() - } - } - - emitter.onError { - logger.warn("Subscription {} had a connection error", sessionId) - subscriber.subscription.cancel() - } - - emitter.onTimeout { - logger.warn("Subscription {} timed out", sessionId) - subscriber.subscription.cancel() + logger.error(errorMessage) + throw ServerWebInputException(errorMessage) } val publisher = try { executionResult.getData>() - } catch (ex: ClassCastException) { - return if (query.contains("subscription")) { - logger.error("Invalid return type for subscription datafetcher. A subscription datafetcher must return a Publisher. The query was $query", ex) - emitter.send("Invalid return type for subscription datafetcher. Was a non-subscription query send to the subscription endpoint?") - emitter.complete() - ResponseEntity.status(500).body(emitter) - } else { - logger.warn("Invalid return type for subscription datafetcher. The query sent doesn't appear to be a subscription query: $query", ex) - emitter.send("Invalid return type for subscription datafetcher. Was a non-subscription query send to the subscription endpoint?") - emitter.complete() - ResponseEntity.badRequest().body(emitter) - } + } catch (exc: ClassCastException) { + logger.error( + "Invalid return type for subscription datafetcher. A subscription datafetcher must return a Publisher. The query was {}", + query, exc + ) + throw ServerErrorException("Invalid return type for subscription datafetcher. Was a non-subscription query send to the subscription endpoint?", exc) } - publisher.subscribe(subscriber) + val subscriptionId = UUID.randomUUID().toString() + return Flux.from(publisher) + .map { + val payload = SSEDataPayload(data = it.getData(), errors = it.errors, subId = subscriptionId) + ServerSentEvent.builder(mapper.writeValueAsString(payload)) + .id(UUID.randomUUID().toString()) + .build() + }.onErrorResume { exc -> + logger.warn("An exception occurred on subscription {}", subscriptionId, exc) + val errorMessage = exc.message ?: "An exception occurred" + val payload = SSEDataPayload(data = null, errors = listOf(SseError(errorMessage)), subId = subscriptionId) + Flux.just( + ServerSentEvent.builder(mapper.writeValueAsString(payload)) + .id(UUID.randomUUID().toString()) + .build() + ) + } + } - return ResponseEntity.ok(emitter) + private fun isSubscriptionQuery(query: String): Boolean { + val document = try { + Parser().parseDocument(query) + } catch (exc: InvalidSyntaxException) { + return false + } + val definitions = document.getDefinitionsOfType(OperationDefinition::class.java) + return definitions.isNotEmpty() && + definitions.all { def -> def.operation == OperationDefinition.Operation.SUBSCRIPTION } } companion object { diff --git a/graphql-dgs-subscriptions-sse/src/test/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandlerTest.kt b/graphql-dgs-subscriptions-sse/src/test/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandlerTest.kt index e2207128b..547706347 100644 --- a/graphql-dgs-subscriptions-sse/src/test/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandlerTest.kt +++ b/graphql-dgs-subscriptions-sse/src/test/kotlin/com/netflix/graphql/dgs/subscriptions/sse/DgsSSESubscriptionHandlerTest.kt @@ -16,95 +16,107 @@ package com.netflix.graphql.dgs.subscriptions.sse +import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper +import com.fasterxml.jackson.module.kotlin.readValue import com.netflix.graphql.dgs.DgsQueryExecutor import com.netflix.graphql.types.subscription.QueryPayload -import graphql.ExecutionResult +import com.netflix.graphql.types.subscription.SSEDataPayload +import graphql.ExecutionResultImpl import graphql.GraphqlErrorBuilder import graphql.validation.ValidationError -import io.mockk.every -import io.mockk.impl.annotations.MockK -import io.mockk.junit5.MockKExtension -import io.mockk.mockk -import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test -import org.junit.jupiter.api.extension.ExtendWith -import org.reactivestreams.Publisher +import org.mockito.ArgumentMatchers.eq +import org.mockito.Mockito.any +import org.mockito.Mockito.`when` +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.autoconfigure.SpringBootApplication +import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest +import org.springframework.boot.test.mock.mockito.MockBean +import org.springframework.http.MediaType +import org.springframework.test.web.servlet.MockMvc +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get +import org.springframework.test.web.servlet.result.MockMvcResultMatchers.content +import org.springframework.test.web.servlet.result.MockMvcResultMatchers.request +import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status import reactor.core.publisher.Flux -import java.util.* +import java.util.Base64 -@ExtendWith(MockKExtension::class) +@WebMvcTest(DgsSSESubscriptionHandler::class, DgsSSESubscriptionHandlerTest.App::class) internal class DgsSSESubscriptionHandlerTest { - @MockK + @SpringBootApplication + open class App + + @Autowired + lateinit var mockMvc: MockMvc + + @MockBean lateinit var dgsQueryExecutor: DgsQueryExecutor - @MockK - lateinit var executionResultMock: ExecutionResult + private val mapper: ObjectMapper = jacksonObjectMapper() @Test fun queryError() { - val query = "subscription { stocks { name, price }}" val queryPayload = QueryPayload(operationName = "MySubscription", query = query) - val base64 = Base64.getEncoder().encodeToString(jacksonObjectMapper().writeValueAsBytes(queryPayload)) + val encodedQuery = Base64.getEncoder().encodeToString(mapper.writeValueAsBytes(queryPayload)) + val executionResult = ExecutionResultImpl.newExecutionResult() + .errors(listOf(GraphqlErrorBuilder.newError().message("broken").build())) + .build() - every { dgsQueryExecutor.execute(query, any()) } returns executionResultMock - every { executionResultMock.errors } returns listOf(GraphqlErrorBuilder.newError().message("broken").build()) + `when`(dgsQueryExecutor.execute(eq(query), any())).thenReturn(executionResult) - val responseEntity = DgsSSESubscriptionHandler(dgsQueryExecutor).subscriptionWithId(base64) - assertThat(responseEntity.statusCode.is5xxServerError).isTrue + mockMvc.perform(get("/subscriptions").param("query", encodedQuery)) + .andExpect(status().is4xxClientError) } @Test fun base64Error() { - - val query = "subscription { stocks { name, price }}" - val base64 = "notbase64" - - every { dgsQueryExecutor.execute(query, any()) } returns executionResultMock - - val responseEntity = DgsSSESubscriptionHandler(dgsQueryExecutor).subscriptionWithId(base64) - assertThat(responseEntity.statusCode.is4xxClientError).isTrue + mockMvc.perform(get("/subscriptions").param("query", "notbase64")) + .andExpect(status().is4xxClientError) } @Test fun queryValidationError() { - val query = "subscription { stocks { name, price }}" val queryPayload = QueryPayload(operationName = "MySubscription", query = query) - val base64 = Base64.getEncoder().encodeToString(jacksonObjectMapper().writeValueAsBytes(queryPayload)) + val encodedQuery = Base64.getEncoder().encodeToString(mapper.writeValueAsBytes(queryPayload)) + + val executionResult = ExecutionResultImpl.newExecutionResult() + .errors(listOf(ValidationError.newValidationError().build())) + .build() - every { dgsQueryExecutor.execute(query, any()) } returns executionResultMock - every { executionResultMock.errors } returns listOf(ValidationError.newValidationError().build()) + `when`(dgsQueryExecutor.execute(eq(query), any())).thenReturn(executionResult) - val responseEntity = DgsSSESubscriptionHandler(dgsQueryExecutor).subscriptionWithId(base64) - assertThat(responseEntity.statusCode.is4xxClientError).isTrue + mockMvc.perform(get("/subscriptions").param("query", encodedQuery)) + .andExpect(status().is4xxClientError) } @Test fun invalidJson() { + val encodedQuery = Base64.getEncoder().encodeToString("not json".toByteArray()) - val query = "subscription { stocks { name, price }}" - val base64 = Base64.getEncoder().encodeToString("not json".toByteArray()) - every { dgsQueryExecutor.execute(query, any()) } returns executionResultMock - val responseEntity = DgsSSESubscriptionHandler(dgsQueryExecutor).subscriptionWithId(base64) - assertThat(responseEntity.statusCode.is4xxClientError).isTrue + mockMvc.perform(get("/subscriptions").param("query", encodedQuery)) + .andExpect(status().is4xxClientError) } @Test fun notAPublisherServerError() { - val query = "subscription { stocks { name, price }}" val queryPayload = QueryPayload(operationName = "MySubscription", query = query) - val base64 = Base64.getEncoder().encodeToString(jacksonObjectMapper().writeValueAsBytes(queryPayload)) + val encodedQuery = Base64.getEncoder().encodeToString(mapper.writeValueAsBytes(queryPayload)) + + val executionResult = ExecutionResultImpl.newExecutionResult() + .data("not a publisher") + .build() - every { dgsQueryExecutor.execute(query, any()) } returns executionResultMock - every { executionResultMock.errors } returns emptyList() - every { executionResultMock.getData>() } throws ClassCastException() + `when`(dgsQueryExecutor.execute(eq(query), any())).thenReturn(executionResult) - val responseEntity = DgsSSESubscriptionHandler(dgsQueryExecutor).subscriptionWithId(base64) - assertThat(responseEntity.statusCode.is5xxServerError).isTrue + mockMvc.perform(get("/subscriptions").param("query", encodedQuery)) + .andExpect(status().is5xxServerError) } @Test @@ -112,31 +124,50 @@ internal class DgsSSESubscriptionHandlerTest { // Not a subscription query val query = "query { stocks { name, price }}" val queryPayload = QueryPayload(operationName = "MySubscription", query = query) - val base64 = Base64.getEncoder().encodeToString(jacksonObjectMapper().writeValueAsBytes(queryPayload)) + val encodedQuery = Base64.getEncoder().encodeToString(mapper.writeValueAsBytes(queryPayload)) - every { dgsQueryExecutor.execute(query, any()) } returns executionResultMock - every { executionResultMock.errors } returns emptyList() - every { executionResultMock.getData>() } throws ClassCastException() + val executionResult = ExecutionResultImpl.newExecutionResult() + .data(mapOf("stocks" to listOf(mapOf("name" to "VTI", "price" to 200)))) + .build() - val responseEntity = DgsSSESubscriptionHandler(dgsQueryExecutor).subscriptionWithId(base64) - assertThat(responseEntity.statusCode.is4xxClientError).isTrue + `when`(dgsQueryExecutor.execute(eq(query), any())).thenReturn(executionResult) + + mockMvc.perform(get("/subscriptions").param("query", encodedQuery)) + .andExpect(status().is4xxClientError) } @Test - @Suppress("ReactiveStreamsUnusedPublisher") fun success() { - val query = "query { stocks { name, price }}" + val query = "subscription { stocks { name, price }}" val queryPayload = QueryPayload(operationName = "MySubscription", query = query) - val base64 = Base64.getEncoder().encodeToString(jacksonObjectMapper().writeValueAsBytes(queryPayload)) - - val nestedExecutionResult = mockk() - - every { dgsQueryExecutor.execute(query, any()) } returns executionResultMock - every { executionResultMock.errors } returns emptyList() - every { executionResultMock.getData>() } returns Flux.just(nestedExecutionResult) - every { nestedExecutionResult.getData() } returns "message 1" - - val responseEntity = DgsSSESubscriptionHandler(dgsQueryExecutor).subscriptionWithId(base64) - assertThat(responseEntity.statusCode.is2xxSuccessful).isTrue + val encodedQuery = Base64.getEncoder().encodeToString(mapper.writeValueAsBytes(queryPayload)) + + val publisher = Flux.just( + ExecutionResultImpl.newExecutionResult().data("message 1").build(), + ExecutionResultImpl.newExecutionResult().data("message 2").build() + ) + val executionResult = ExecutionResultImpl.newExecutionResult() + .data(publisher).build() + + `when`(dgsQueryExecutor.execute(eq(query), any())).thenReturn(executionResult) + + val result = mockMvc.perform(get("/subscriptions").param("query", encodedQuery)) + .andExpect(request().asyncStarted()) + .andExpect(status().is2xxSuccessful) + .andReturn() + + mockMvc.perform(asyncDispatch(result)) + .andExpect(content().contentType(MediaType.TEXT_EVENT_STREAM)) + .andReturn() + + val messages = result.response.contentAsString.lineSequence() + .filter { line -> line.startsWith("data:") } + .map { line -> line.substring("data:".length) } + .map { line -> mapper.readValue(line) } + .toList() + + assertEquals(2, messages.size) + assertEquals("message 1", messages[0].data) + assertEquals("message 2", messages[1].data) } }