Skip to content

Commit

Permalink
Merge pull request #1200 from Netflix/feat-subscriptions-graphql-ws
Browse files Browse the repository at this point in the history
Implement `graphql-transport-ws` protocol for websocket subscriptions (webmvc & webflux)
  • Loading branch information
srinivasankavitha committed Aug 29, 2022
2 parents eb6493e + 7938fa7 commit 8fa2368
Show file tree
Hide file tree
Showing 26 changed files with 23,126 additions and 375 deletions.
3 changes: 3 additions & 0 deletions .idea/codeStyles/Project.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21,586 changes: 21,527 additions & 59 deletions graphql-dgs-example-shared/ui-example/package-lock.json

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion graphql-dgs-example-shared/ui-example/package.json
Expand Up @@ -3,14 +3,15 @@
"version": "1.0.0",
"private": true,
"dependencies": {
"@apollo/client": "^3.2.7",
"@apollo/client": "^3.5.10",
"@reach/router": "^1.2.1",
"@types/node": "^12.12.14",
"@types/reach__router": "^1.2.6",
"@types/react": "^16.9.15",
"@types/react-dom": "^16.9.4",
"emotion": "^9.2.12",
"graphql": "^14.4.2",
"graphql-ws": "^5.10.0",
"polished": "^3.4.1",
"react": "^16.12.0",
"react-dom": "^16.12.0",
Expand Down
13 changes: 7 additions & 6 deletions graphql-dgs-example-shared/ui-example/src/index.tsx
Expand Up @@ -28,14 +28,15 @@ import {
useSubscription
} from '@apollo/client';

import {WebSocketLink} from "@apollo/client/link/ws";
import { GraphQLWsLink } from "@apollo/client/link/subscriptions";
import { createClient } from 'graphql-ws';

const httpLink = createHttpLink({uri:'http://localhost:8080/graphql' })
const webSocketLink = new GraphQLWsLink(createClient({
url: 'ws://localhost:8080/subscriptions',
}));

const webSocketLink = new WebSocketLink({
uri: 'ws://localhost:8080/subscriptions'
});

const httpLink = createHttpLink({uri:'http://localhost:8080/graphql' })
const client: ApolloClient<NormalizedCacheObject> = new ApolloClient({
link: split((operation) => {
return operation.operationName === "StockWatch"
Expand Down Expand Up @@ -117,4 +118,4 @@ ReactDOM.render(
<App/>
</ApolloProvider>,
document.getElementById('root'),
);
);
1 change: 1 addition & 0 deletions graphql-dgs-spring-webflux-autoconfigure/build.gradle.kts
Expand Up @@ -17,6 +17,7 @@
dependencies {
api(project(":graphql-dgs"))
api(project(":graphql-dgs-reactive"))
api(project(":graphql-dgs-subscription-types"))

implementation("org.springframework.boot:spring-boot-starter")
implementation("org.springframework:spring-webflux")
Expand Down
Expand Up @@ -204,9 +204,9 @@ open class DgsWebFluxAutoConfiguration(private val configProps: DgsWebfluxConfig
}

@Bean
open fun websocketSubscriptionHandler(dgsReactiveQueryExecutor: DgsReactiveQueryExecutor): SimpleUrlHandlerMapping {
open fun websocketSubscriptionHandler(dgsReactiveQueryExecutor: DgsReactiveQueryExecutor, webfluxConfigurationProperties: DgsWebfluxConfigurationProperties): SimpleUrlHandlerMapping {
val simpleUrlHandlerMapping =
SimpleUrlHandlerMapping(mapOf("/subscriptions" to DgsReactiveWebsocketHandler(dgsReactiveQueryExecutor)))
SimpleUrlHandlerMapping(mapOf("/subscriptions" to DgsReactiveWebsocketHandler(dgsReactiveQueryExecutor, webfluxConfigurationProperties.websocket.connectionInitTimeout)))
simpleUrlHandlerMapping.order = 1
return simpleUrlHandlerMapping
}
Expand Down
Expand Up @@ -20,17 +20,30 @@ import org.springframework.boot.context.properties.ConfigurationProperties
import org.springframework.boot.context.properties.ConstructorBinding
import org.springframework.boot.context.properties.NestedConfigurationProperty
import org.springframework.boot.context.properties.bind.DefaultValue
import java.time.Duration
import javax.annotation.PostConstruct

@ConstructorBinding
@ConfigurationProperties(prefix = "dgs.graphql")
@Suppress("ConfigurationProperties")
class DgsWebfluxConfigurationProperties(
/** Websocket configuration. */
@NestedConfigurationProperty var websocket: DgsWebsocketConfigurationProperties = DgsWebsocketConfigurationProperties(
DEFAULT_CONNECTION_INIT_TIMEOUT_DURATION
),
/** Path to the endpoint that will serve GraphQL requests. */
@DefaultValue("/graphql") var path: String = "/graphql",
@NestedConfigurationProperty var graphiql: DgsGraphiQLConfigurationProperties = DgsGraphiQLConfigurationProperties(),
@NestedConfigurationProperty var schemaJson: DgsSchemaJsonConfigurationProperties = DgsSchemaJsonConfigurationProperties()
) {
/**
* Configuration properties for websockets.
*/
data class DgsWebsocketConfigurationProperties(
/** Connection Initialization timeout for graphql-transport-ws. */
@DefaultValue(DEFAULT_CONNECTION_INIT_TIMEOUT) var connectionInitTimeout: Duration
)

/**
* Configuration properties for the GraphiQL endpoint.
*/
Expand Down Expand Up @@ -60,4 +73,9 @@ class DgsWebfluxConfigurationProperties(
throw IllegalArgumentException("$pathProperty must start with '/' and not end with '/' but was '$path'")
}
}

companion object {
const val DEFAULT_CONNECTION_INIT_TIMEOUT = "10s"
val DEFAULT_CONNECTION_INIT_TIMEOUT_DURATION: Duration = Duration.ofSeconds(10)
}
}
Expand Up @@ -16,165 +16,27 @@

package com.netflix.graphql.dgs.webflux.handlers

import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.module.kotlin.convertValue
import com.netflix.graphql.dgs.reactive.DgsReactiveQueryExecutor
import graphql.ExecutionResult
import org.reactivestreams.Publisher
import org.reactivestreams.Subscription
import org.slf4j.LoggerFactory
import org.springframework.core.ResolvableType
import org.springframework.core.io.buffer.DataBuffer
import org.springframework.core.io.buffer.DataBufferUtils
import org.springframework.http.codec.json.Jackson2JsonDecoder
import org.springframework.http.codec.json.Jackson2JsonEncoder
import org.springframework.util.MimeTypeUtils
import com.netflix.graphql.types.subscription.GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL
import com.netflix.graphql.types.subscription.GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL
import org.springframework.web.reactive.socket.WebSocketHandler
import org.springframework.web.reactive.socket.WebSocketMessage
import org.springframework.web.reactive.socket.WebSocketSession
import reactor.core.publisher.Flux
import reactor.core.publisher.Mono
import java.util.concurrent.ConcurrentHashMap
import java.time.Duration

class DgsReactiveWebsocketHandler(private val dgsReactiveQueryExecutor: DgsReactiveQueryExecutor) : WebSocketHandler {
class DgsReactiveWebsocketHandler(dgsReactiveQueryExecutor: DgsReactiveQueryExecutor, connectionInitTimeout: Duration) : WebSocketHandler {

private val resolvableType = ResolvableType.forType(OperationMessage::class.java)
private val subscriptions = ConcurrentHashMap<String, MutableMap<String, Subscription>>()
private val decoder = Jackson2JsonDecoder()
private val encoder = Jackson2JsonEncoder(decoder.objectMapper)

override fun getSubProtocols(): List<String> = listOf("graphql-ws")
private val graphqlWSHandler = WebsocketGraphQLWSProtocolHandler(dgsReactiveQueryExecutor)
private val graphqlTransportWSHandler = WebsocketGraphQLTransportWSProtocolHandler(dgsReactiveQueryExecutor, connectionInitTimeout)
override fun getSubProtocols(): List<String> = listOf(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL)

override fun handle(webSocketSession: WebSocketSession): Mono<Void> {
return webSocketSession.send(
webSocketSession.receive()
.flatMap { message ->
val buffer: DataBuffer = DataBufferUtils.retain(message.payload)

val operationMessage: OperationMessage = decoder.decode(
buffer,
resolvableType,
MimeTypeUtils.APPLICATION_JSON,
null
) as OperationMessage

when (operationMessage.type) {
GQL_CONNECTION_INIT -> Flux.just(
toWebsocketMessage(
OperationMessage(GQL_CONNECTION_ACK), webSocketSession
)
)
GQL_START -> {
val queryPayload = decoder.objectMapper.convertValue<QueryPayload>(
operationMessage.payload ?: error("payload == null")
)
logger.debug("Starting subscription {} for session {}", queryPayload, webSocketSession.id)
dgsReactiveQueryExecutor.execute(queryPayload.query, queryPayload.variables)
.flatMapMany { executionResult ->
val publisher: Publisher<ExecutionResult> = executionResult.getData()
Flux.from(publisher).map { executionResult ->
toWebsocketMessage(
OperationMessage(GQL_DATA, DataPayload(data = executionResult.getData(), errors = executionResult.errors), operationMessage.id),
webSocketSession
)
}.doOnSubscribe {
if (operationMessage.id != null) {
subscriptions[webSocketSession.id] = mutableMapOf(operationMessage.id to it)
}
}.doOnComplete {
webSocketSession.send(
Flux.just(
toWebsocketMessage(
OperationMessage(GQL_COMPLETE, null, operationMessage.id),
webSocketSession
)
)
).subscribe()

subscriptions[webSocketSession.id]?.remove(operationMessage.id)
logger.debug(
"Completing subscription {} for connection {}",
operationMessage.id, webSocketSession.id
)
}.doOnError {
webSocketSession.send(
Flux.just(
toWebsocketMessage(
OperationMessage(GQL_ERROR, DataPayload(null, listOf(it.message!!)), operationMessage.id),
webSocketSession
)
)
).subscribe()

subscriptions[webSocketSession.id]?.remove(operationMessage.id)
logger.debug(
"Subscription publisher error for input {} for subscription {} for connection {}",
queryPayload, operationMessage.id, webSocketSession.id, it
)
}
}
}

GQL_STOP -> {
subscriptions[webSocketSession.id]?.remove(operationMessage.id)?.cancel()
logger.debug(
"Client stopped subscription {} for connection {}",
operationMessage.id, webSocketSession.id
)
Flux.empty()
}

GQL_CONNECTION_TERMINATE -> {
subscriptions[webSocketSession.id]?.values?.forEach { it.cancel() }
subscriptions.remove(webSocketSession.id)
webSocketSession.close()
logger.debug("Connection {} terminated", webSocketSession.id)
Flux.empty()
}
if (webSocketSession.handshakeInfo.subProtocol.equals(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, ignoreCase = true)) {
return graphqlWSHandler.handle(webSocketSession)
} else if (webSocketSession.handshakeInfo.subProtocol.equals(GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL, ignoreCase = true)) {
return graphqlTransportWSHandler.handle(webSocketSession)
}

else -> Flux.empty()
}
}
)
}

private fun toWebsocketMessage(operationMessage: OperationMessage, session: WebSocketSession): WebSocketMessage {
return WebSocketMessage(
WebSocketMessage.Type.TEXT,
encoder.encodeValue(
operationMessage,
session.bufferFactory(),
resolvableType,
MimeTypeUtils.APPLICATION_JSON,
null
)
)
}

companion object {
private val logger = LoggerFactory.getLogger(DgsReactiveQueryExecutor::class.java)

const val GQL_CONNECTION_INIT = "connection_init"
const val GQL_CONNECTION_ACK = "connection_ack"
const val GQL_START = "start"
const val GQL_STOP = "stop"
const val GQL_DATA = "data"
const val GQL_ERROR = "error"
const val GQL_COMPLETE = "complete"
const val GQL_CONNECTION_TERMINATE = "connection_terminate"
return Mono.empty()
}
}

data class DataPayload(val data: Any?, val errors: List<Any>? = emptyList())
data class OperationMessage(
@JsonProperty("type") val type: String,
@JsonProperty("payload") val payload: Any? = null,
@JsonProperty("id", required = false) val id: String? = ""
)

data class QueryPayload(
@JsonProperty("variables") val variables: Map<String, Any> = emptyMap(),
@JsonProperty("extensions") val extensions: Map<String, Any> = emptyMap(),
@JsonProperty("operationName") val operationName: String?,
@JsonProperty("query") val query: String
)

0 comments on commit 8fa2368

Please sign in to comment.