diff --git a/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/DgsWebSocketHandler.kt b/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/DgsWebSocketHandler.kt index d889665e1..6724a4a77 100644 --- a/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/DgsWebSocketHandler.kt +++ b/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/DgsWebSocketHandler.kt @@ -20,6 +20,9 @@ import com.netflix.graphql.dgs.DgsQueryExecutor import com.netflix.graphql.types.subscription.* import org.slf4j.LoggerFactory import org.slf4j.event.Level +import org.springframework.security.core.context.SecurityContext +import org.springframework.security.core.context.SecurityContextHolder +import org.springframework.util.ClassUtils import org.springframework.web.socket.CloseStatus import org.springframework.web.socket.SubProtocolCapable import org.springframework.web.socket.TextMessage @@ -77,6 +80,7 @@ class DgsWebSocketHandler(dgsQueryExecutor: DgsQueryExecutor, connectionInitTime } public override fun handleTextMessage(session: WebSocketSession, message: TextMessage) { + loadSecurityContextFromSession(session) if (session.acceptedProtocol.equals(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, ignoreCase = true)) { return graphqlWSHandler.handleTextMessage(session, message) } else if (session.acceptedProtocol.equals(GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL, ignoreCase = true)) { @@ -84,9 +88,25 @@ class DgsWebSocketHandler(dgsQueryExecutor: DgsQueryExecutor, connectionInitTime } } + private fun loadSecurityContextFromSession(session: WebSocketSession) { + if (springSecurityAvailable) { + val securityContext = session.attributes["SPRING_SECURITY_CONTEXT"] as? SecurityContext + if (securityContext != null) { + SecurityContextHolder.setContext(securityContext) + } + } + } + override fun getSubProtocols(): List = listOf(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL) private companion object { val logger = LoggerFactory.getLogger(DgsWebSocketHandler::class.java) + + private val springSecurityAvailable: Boolean by lazy { + ClassUtils.isPresent( + "org.springframework.security.core.context.SecurityContextHolder", + DgsWebSocketHandler::class.java.classLoader + ) + } } } diff --git a/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/WebsocketGraphQLWSProtocolHandler.kt b/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/WebsocketGraphQLWSProtocolHandler.kt index a6071cf21..5d2516a02 100644 --- a/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/WebsocketGraphQLWSProtocolHandler.kt +++ b/graphql-dgs-subscriptions-websockets/src/main/kotlin/com/netflix/graphql/dgs/subscriptions/websockets/WebsocketGraphQLWSProtocolHandler.kt @@ -25,9 +25,6 @@ import org.reactivestreams.Subscriber import org.reactivestreams.Subscription import org.slf4j.LoggerFactory import org.slf4j.event.Level -import org.springframework.security.core.context.SecurityContext -import org.springframework.security.core.context.SecurityContextHolder -import org.springframework.util.ClassUtils import org.springframework.web.socket.TextMessage import org.springframework.web.socket.WebSocketSession import org.springframework.web.socket.handler.TextWebSocketHandler @@ -55,7 +52,6 @@ class WebsocketGraphQLWSProtocolHandler(private val dgsQueryExecutor: DgsQueryEx public override fun handleTextMessage(session: WebSocketSession, message: TextMessage) { val (type, payload, id) = objectMapper.readValue(message.payload, OperationMessage::class.java) - loadSecurityContextFromSession(session) when (type) { GQL_CONNECTION_INIT -> { logger.info("Initialized connection for {}", session.id) @@ -88,15 +84,6 @@ class WebsocketGraphQLWSProtocolHandler(private val dgsQueryExecutor: DgsQueryEx } } - private fun loadSecurityContextFromSession(session: WebSocketSession) { - if (springSecurityAvailable) { - val securityContext = session.attributes["SPRING_SECURITY_CONTEXT"] as? SecurityContext - if (securityContext != null) { - SecurityContextHolder.setContext(securityContext) - } - } - } - private fun cleanupSubscriptionsForSession(session: WebSocketSession) { logger.info("Cleaning up for session {}", session.id) subscriptions[session.id]?.values?.forEach { it.cancel() } @@ -162,12 +149,5 @@ class WebsocketGraphQLWSProtocolHandler(private val dgsQueryExecutor: DgsQueryEx private companion object { val logger = LoggerFactory.getLogger(WebsocketGraphQLWSProtocolHandler::class.java) val objectMapper = jacksonObjectMapper() - - private val springSecurityAvailable: Boolean by lazy { - ClassUtils.isPresent( - "org.springframework.security.core.context.SecurityContextHolder", - WebsocketGraphQLWSProtocolHandler::class.java.classLoader - ) - } } }