From f0fe2f1df0056dad7fa13c74fa707a1235ea37d9 Mon Sep 17 00:00:00 2001 From: Yifei Zhang Date: Mon, 24 Aug 2020 12:04:36 +0800 Subject: [PATCH] Remove session on 4xx response from WebSocket handshake #25608 --- .../transport/TransportHandlingSockJsService.java | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 40a19881077c..8a5acabe3383 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -34,6 +34,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; @@ -270,6 +271,7 @@ else if (transportType.supportsCors()) { } SockJsSession session = this.sessions.get(sessionId); + boolean newSession = false; if (session == null) { if (transportHandler instanceof SockJsSessionFactory) { Map attributes = new HashMap<>(); @@ -278,6 +280,7 @@ else if (transportType.supportsCors()) { } SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; session = createSockJsSession(sessionId, sessionFactory, handler, attributes); + newSession = true; } else { response.setStatusCode(HttpStatus.NOT_FOUND); @@ -311,6 +314,14 @@ else if (transportType.supportsCors()) { } transportHandler.handleRequest(request, response, handler, session); + + if (newSession && (response instanceof ServletServerHttpResponse)) { + int status = ((ServletServerHttpResponse) response).getServletResponse().getStatus(); + if (HttpStatus.valueOf(status).is4xxClientError()) { + sessions.remove(sessionId); + } + } + chain.applyAfterHandshake(request, response, null); } catch (SockJsException ex) {