diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java index 40a19881077c..272700b49a44 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/TransportHandlingSockJsService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.http.HttpStatus; import org.springframework.http.server.ServerHttpRequest; import org.springframework.http.server.ServerHttpResponse; +import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.lang.Nullable; import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; @@ -270,6 +271,7 @@ else if (transportType.supportsCors()) { } SockJsSession session = this.sessions.get(sessionId); + boolean isNewSession = false; if (session == null) { if (transportHandler instanceof SockJsSessionFactory) { Map attributes = new HashMap<>(); @@ -278,6 +280,7 @@ else if (transportType.supportsCors()) { } SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler; session = createSockJsSession(sessionId, sessionFactory, handler, attributes); + isNewSession = true; } else { response.setStatusCode(HttpStatus.NOT_FOUND); @@ -311,6 +314,14 @@ else if (transportType.supportsCors()) { } transportHandler.handleRequest(request, response, handler, session); + + if (isNewSession && (response instanceof ServletServerHttpResponse)) { + int status = ((ServletServerHttpResponse) response).getServletResponse().getStatus(); + if (HttpStatus.valueOf(status).is4xxClientError()) { + this.sessions.remove(sessionId); + } + } + chain.applyAfterHandshake(request, response, null); } catch (SockJsException ex) {