Skip to content

Commit

Permalink
Reuse StandardWebSocketUpgradeStrategy as a base class for Tomcat etc
Browse files Browse the repository at this point in the history
Includes non-reflective instantiation of well-known strategy classes.

See gh-29436
  • Loading branch information
jhoeller committed Nov 7, 2022
1 parent 465575f commit a2ac764
Show file tree
Hide file tree
Showing 14 changed files with 379 additions and 382 deletions.
Expand Up @@ -38,12 +38,17 @@
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.reactive.socket.server.upgrade.JettyRequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.ReactorNetty2RequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.ReactorNettyRequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.StandardWebSocketUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.TomcatRequestUpgradeStrategy;
import org.springframework.web.reactive.socket.server.upgrade.UndertowRequestUpgradeStrategy;
import org.springframework.web.server.MethodNotAllowedException;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebInputException;
Expand All @@ -55,6 +60,7 @@
* also be explicitly configured.
*
* @author Rossen Stoyanchev
* @author Juergen Hoeller
* @since 5.0
*/
public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
Expand All @@ -66,28 +72,32 @@ public class HandshakeWebSocketService implements WebSocketService, Lifecycle {
private static final Mono<Map<String, Object>> EMPTY_ATTRIBUTES = Mono.just(Collections.emptyMap());


private static final boolean tomcatPresent;
private static final boolean tomcatWsPresent;

private static final boolean jettyPresent;
private static final boolean jettyWsPresent;

private static final boolean undertowPresent;
private static final boolean undertowWsPresent;

private static final boolean reactorNettyPresent;

private static final boolean reactorNetty2Present;

static {
ClassLoader loader = HandshakeWebSocketService.class.getClassLoader();
tomcatPresent = ClassUtils.isPresent("org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", loader);
jettyPresent = ClassUtils.isPresent("org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer", loader);
undertowPresent = ClassUtils.isPresent("io.undertow.websockets.WebSocketProtocolHandshakeHandler", loader);
reactorNettyPresent = ClassUtils.isPresent("reactor.netty.http.server.HttpServerResponse", loader);
reactorNetty2Present = ClassUtils.isPresent("reactor.netty5.http.server.HttpServerResponse", loader);
ClassLoader classLoader = HandshakeWebSocketService.class.getClassLoader();
tomcatWsPresent = ClassUtils.isPresent(
"org.apache.tomcat.websocket.server.WsHttpUpgradeHandler", classLoader);
jettyWsPresent = ClassUtils.isPresent(
"org.eclipse.jetty.websocket.server.JettyWebSocketServerContainer", classLoader);
undertowWsPresent = ClassUtils.isPresent(
"io.undertow.websockets.WebSocketProtocolHandshakeHandler", classLoader);
reactorNettyPresent = ClassUtils.isPresent(
"reactor.netty.http.server.HttpServerResponse", classLoader);
reactorNetty2Present = ClassUtils.isPresent(
"reactor.netty5.http.server.HttpServerResponse", classLoader);
}


protected static final Log logger = LogFactory.getLog(HandshakeWebSocketService.class);

private static final Log logger = LogFactory.getLog(HandshakeWebSocketService.class);

private final RequestUpgradeStrategy upgradeStrategy;

Expand All @@ -114,40 +124,6 @@ public HandshakeWebSocketService(RequestUpgradeStrategy upgradeStrategy) {
this.upgradeStrategy = upgradeStrategy;
}

static RequestUpgradeStrategy initUpgradeStrategy() {
String className;
if (tomcatPresent) {
className = "TomcatRequestUpgradeStrategy";
}
else if (jettyPresent) {
className = "JettyRequestUpgradeStrategy";
}
else if (undertowPresent) {
className = "UndertowRequestUpgradeStrategy";
}
else if (reactorNettyPresent) {
// As late as possible (Reactor Netty commonly used for WebClient)
className = "ReactorNettyRequestUpgradeStrategy";
}
else if (reactorNetty2Present) {
// As late as possible (Reactor Netty commonly used for WebClient)
className = "ReactorNetty2RequestUpgradeStrategy";
}
else {
throw new IllegalStateException("No suitable default RequestUpgradeStrategy found");
}

try {
className = "org.springframework.web.reactive.socket.server.upgrade." + className;
Class<?> clazz = ClassUtils.forName(className, HandshakeWebSocketService.class.getClassLoader());
return (RequestUpgradeStrategy) ReflectionUtils.accessibleConstructor(clazz).newInstance();
}
catch (Throwable ex) {
throw new IllegalStateException(
"Failed to instantiate RequestUpgradeStrategy: " + className, ex);
}
}


/**
* Return the {@link RequestUpgradeStrategy} for WebSocket requests.
Expand Down Expand Up @@ -292,4 +268,44 @@ private HandshakeInfo createHandshakeInfo(ServerWebExchange exchange, ServerHttp
return new HandshakeInfo(uri, headers, cookies, principal, protocol, remoteAddress, attributes, logPrefix);
}


static RequestUpgradeStrategy initUpgradeStrategy() {
if (tomcatWsPresent) {
return new TomcatRequestUpgradeStrategy();
}
else if (jettyWsPresent) {
return new JettyRequestUpgradeStrategy();
}
else if (undertowWsPresent) {
return new UndertowRequestUpgradeStrategy();
}
else if (reactorNettyPresent) {
// As late as possible (Reactor Netty commonly used for WebClient)
return ReactorNettyStrategyDelegate.forReactorNetty1();
}
else if (reactorNetty2Present) {
// As late as possible (Reactor Netty commonly used for WebClient)
return ReactorNettyStrategyDelegate.forReactorNetty2();
}
else {
// Let's assume Jakarta WebSocket API 2.1+
return new StandardWebSocketUpgradeStrategy();
}
}


/**
* Inner class to avoid a reachable dependency on Reactor Netty API.
*/
private static class ReactorNettyStrategyDelegate {

public static RequestUpgradeStrategy forReactorNetty1() {
return new ReactorNettyRequestUpgradeStrategy();
}

public static RequestUpgradeStrategy forReactorNetty2() {
return new ReactorNetty2RequestUpgradeStrategy();
}
}

}
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,7 +40,7 @@
import org.springframework.web.server.ServerWebExchange;

/**
* A {@link RequestUpgradeStrategy} for Jetty 11.
* A WebSocket {@code RequestUpgradeStrategy} for Jetty 11.
*
* @author Rossen Stoyanchev
* @since 5.3.4
Expand Down
Expand Up @@ -35,7 +35,7 @@
import org.springframework.web.server.ServerWebExchange;

/**
* A {@link RequestUpgradeStrategy} for use with Reactor Netty for Netty 5.
* A WebSocket {@code RequestUpgradeStrategy} for Reactor Netty for Netty 5.
*
* <p>This class is based on {@link ReactorNettyRequestUpgradeStrategy}.
*\
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2021 the original author or authors.
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,7 +35,7 @@
import org.springframework.web.server.ServerWebExchange;

/**
* A {@link RequestUpgradeStrategy} for use with Reactor Netty.
* A WebSocket {@code RequestUpgradeStrategy} for Reactor Netty.
*
* @author Rossen Stoyanchev
* @since 5.0
Expand Down
@@ -0,0 +1,199 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.web.reactive.socket.server.upgrade;

import java.util.Collections;
import java.util.Map;
import java.util.function.Supplier;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.websocket.Endpoint;
import jakarta.websocket.server.ServerContainer;
import jakarta.websocket.server.ServerEndpointConfig;
import reactor.core.publisher.Mono;

import org.springframework.core.io.buffer.DataBufferFactory;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.adapter.ContextWebSocketHandler;
import org.springframework.web.reactive.socket.adapter.StandardWebSocketHandlerAdapter;
import org.springframework.web.reactive.socket.adapter.TomcatWebSocketSession;
import org.springframework.web.reactive.socket.server.RequestUpgradeStrategy;
import org.springframework.web.server.ServerWebExchange;

/**
* A WebSocket {@code RequestUpgradeStrategy} for the Jakarta WebSocket API 2.1+.
*
* <p>This strategy serves as a fallback if no specific server has been detected.
* It can also be used with Jakarta EE 10 level servers such as Tomcat 10.1 and
* Undertow 2.3 directly, relying on their built-in Jakarta WebSocket 2.1 support.
*
* @author Juergen Hoeller
* @author Violeta Georgieva
* @author Rossen Stoyanchev
* @since 6.0
* @see jakarta.websocket.server.ServerContainer#upgradeHttpToWebSocket
*/
public class StandardWebSocketUpgradeStrategy implements RequestUpgradeStrategy {

private static final String SERVER_CONTAINER_ATTR = "jakarta.websocket.server.ServerContainer";


@Nullable
private Long asyncSendTimeout;

@Nullable
private Long maxSessionIdleTimeout;

@Nullable
private Integer maxTextMessageBufferSize;

@Nullable
private Integer maxBinaryMessageBufferSize;

@Nullable
private ServerContainer serverContainer;


/**
* Exposes the underlying config option on
* {@link ServerContainer#setAsyncSendTimeout(long)}.
*/
public void setAsyncSendTimeout(Long timeoutInMillis) {
this.asyncSendTimeout = timeoutInMillis;
}

@Nullable
public Long getAsyncSendTimeout() {
return this.asyncSendTimeout;
}

/**
* Exposes the underlying config option on
* {@link ServerContainer#setDefaultMaxSessionIdleTimeout(long)}.
*/
public void setMaxSessionIdleTimeout(Long timeoutInMillis) {
this.maxSessionIdleTimeout = timeoutInMillis;
}

@Nullable
public Long getMaxSessionIdleTimeout() {
return this.maxSessionIdleTimeout;
}

/**
* Exposes the underlying config option on
* {@link ServerContainer#setDefaultMaxTextMessageBufferSize(int)}.
*/
public void setMaxTextMessageBufferSize(Integer bufferSize) {
this.maxTextMessageBufferSize = bufferSize;
}

@Nullable
public Integer getMaxTextMessageBufferSize() {
return this.maxTextMessageBufferSize;
}

/**
* Exposes the underlying config option on
* {@link ServerContainer#setDefaultMaxBinaryMessageBufferSize(int)}.
*/
public void setMaxBinaryMessageBufferSize(Integer bufferSize) {
this.maxBinaryMessageBufferSize = bufferSize;
}

@Nullable
public Integer getMaxBinaryMessageBufferSize() {
return this.maxBinaryMessageBufferSize;
}

@Override
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
@Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory){

ServerHttpRequest request = exchange.getRequest();
ServerHttpResponse response = exchange.getResponse();

HttpServletRequest servletRequest = ServerHttpRequestDecorator.getNativeRequest(request);
HttpServletResponse servletResponse = ServerHttpResponseDecorator.getNativeResponse(response);

HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
DataBufferFactory bufferFactory = response.bufferFactory();

// Trigger WebFlux preCommit actions and upgrade
return exchange.getResponse().setComplete()
.then(Mono.deferContextual(contextView -> {
Endpoint endpoint = new StandardWebSocketHandlerAdapter(
ContextWebSocketHandler.decorate(handler, contextView),
session -> new TomcatWebSocketSession(session, handshakeInfo, bufferFactory));

String requestURI = servletRequest.getRequestURI();
DefaultServerEndpointConfig config = new DefaultServerEndpointConfig(requestURI, endpoint);
config.setSubprotocols(subProtocol != null ?
Collections.singletonList(subProtocol) : Collections.emptyList());

try {
upgradeHttpToWebSocket(servletRequest, servletResponse, config, Collections.emptyMap());
}
catch (Exception ex) {
return Mono.error(ex);
}
return Mono.empty();
}));
}


protected void upgradeHttpToWebSocket(HttpServletRequest request, HttpServletResponse response,
ServerEndpointConfig endpointConfig, Map<String,String> pathParams) throws Exception {

getContainer(request).upgradeHttpToWebSocket(request, response, endpointConfig, pathParams);
}

protected ServerContainer getContainer(HttpServletRequest request) {
if (this.serverContainer == null) {
Object container = request.getServletContext().getAttribute(SERVER_CONTAINER_ATTR);
Assert.state(container instanceof ServerContainer,
"ServletContext attribute 'jakarta.websocket.server.ServerContainer' not found.");
this.serverContainer = (ServerContainer) container;
initServerContainer(this.serverContainer);
}
return this.serverContainer;
}

private void initServerContainer(ServerContainer serverContainer) {
if (this.asyncSendTimeout != null) {
serverContainer.setAsyncSendTimeout(this.asyncSendTimeout);
}
if (this.maxSessionIdleTimeout != null) {
serverContainer.setDefaultMaxSessionIdleTimeout(this.maxSessionIdleTimeout);
}
if (this.maxTextMessageBufferSize != null) {
serverContainer.setDefaultMaxTextMessageBufferSize(this.maxTextMessageBufferSize);
}
if (this.maxBinaryMessageBufferSize != null) {
serverContainer.setDefaultMaxBinaryMessageBufferSize(this.maxBinaryMessageBufferSize);
}
}

}

0 comments on commit a2ac764

Please sign in to comment.