Skip to content

Commit

Permalink
Merge pull request #1313 from HuseinJ/master
Browse files Browse the repository at this point in the history
Make session security context available for both WebSocket protocols
  • Loading branch information
srinivasankavitha committed Nov 9, 2022
2 parents 6aff9aa + 8676a18 commit 19a4eb6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
Expand Up @@ -20,6 +20,9 @@ import com.netflix.graphql.dgs.DgsQueryExecutor
import com.netflix.graphql.types.subscription.*
import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import org.springframework.security.core.context.SecurityContext
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.util.ClassUtils
import org.springframework.web.socket.CloseStatus
import org.springframework.web.socket.SubProtocolCapable
import org.springframework.web.socket.TextMessage
Expand Down Expand Up @@ -77,16 +80,33 @@ class DgsWebSocketHandler(dgsQueryExecutor: DgsQueryExecutor, connectionInitTime
}

public override fun handleTextMessage(session: WebSocketSession, message: TextMessage) {
loadSecurityContextFromSession(session)
if (session.acceptedProtocol.equals(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, ignoreCase = true)) {
return graphqlWSHandler.handleTextMessage(session, message)
} else if (session.acceptedProtocol.equals(GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL, ignoreCase = true)) {
return graphqlTransportWSHandler.handleTextMessage(session, message)
}
}

private fun loadSecurityContextFromSession(session: WebSocketSession) {
if (springSecurityAvailable) {
val securityContext = session.attributes["SPRING_SECURITY_CONTEXT"] as? SecurityContext
if (securityContext != null) {
SecurityContextHolder.setContext(securityContext)
}
}
}

override fun getSubProtocols(): List<String> = listOf(GRAPHQL_SUBSCRIPTIONS_WS_PROTOCOL, GRAPHQL_SUBSCRIPTIONS_TRANSPORT_WS_PROTOCOL)

private companion object {
val logger = LoggerFactory.getLogger(DgsWebSocketHandler::class.java)

private val springSecurityAvailable: Boolean by lazy {
ClassUtils.isPresent(
"org.springframework.security.core.context.SecurityContextHolder",
DgsWebSocketHandler::class.java.classLoader
)
}
}
}
Expand Up @@ -25,9 +25,6 @@ import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
import org.slf4j.LoggerFactory
import org.slf4j.event.Level
import org.springframework.security.core.context.SecurityContext
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.util.ClassUtils
import org.springframework.web.socket.TextMessage
import org.springframework.web.socket.WebSocketSession
import org.springframework.web.socket.handler.TextWebSocketHandler
Expand Down Expand Up @@ -55,7 +52,6 @@ class WebsocketGraphQLWSProtocolHandler(private val dgsQueryExecutor: DgsQueryEx

public override fun handleTextMessage(session: WebSocketSession, message: TextMessage) {
val (type, payload, id) = objectMapper.readValue(message.payload, OperationMessage::class.java)
loadSecurityContextFromSession(session)
when (type) {
GQL_CONNECTION_INIT -> {
logger.info("Initialized connection for {}", session.id)
Expand Down Expand Up @@ -88,15 +84,6 @@ class WebsocketGraphQLWSProtocolHandler(private val dgsQueryExecutor: DgsQueryEx
}
}

private fun loadSecurityContextFromSession(session: WebSocketSession) {
if (springSecurityAvailable) {
val securityContext = session.attributes["SPRING_SECURITY_CONTEXT"] as? SecurityContext
if (securityContext != null) {
SecurityContextHolder.setContext(securityContext)
}
}
}

private fun cleanupSubscriptionsForSession(session: WebSocketSession) {
logger.info("Cleaning up for session {}", session.id)
subscriptions[session.id]?.values?.forEach { it.cancel() }
Expand Down Expand Up @@ -162,12 +149,5 @@ class WebsocketGraphQLWSProtocolHandler(private val dgsQueryExecutor: DgsQueryEx
private companion object {
val logger = LoggerFactory.getLogger(WebsocketGraphQLWSProtocolHandler::class.java)
val objectMapper = jacksonObjectMapper()

private val springSecurityAvailable: Boolean by lazy {
ClassUtils.isPresent(
"org.springframework.security.core.context.SecurityContextHolder",
WebsocketGraphQLWSProtocolHandler::class.java.classLoader
)
}
}
}

0 comments on commit 19a4eb6

Please sign in to comment.