diff --git a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/netty/SslServerCustomizer.java b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/netty/SslServerCustomizer.java index 137a0d2cf974..f86f4e7918c0 100644 --- a/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/netty/SslServerCustomizer.java +++ b/spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/embedded/netty/SslServerCustomizer.java @@ -16,12 +16,27 @@ package org.springframework.boot.web.embedded.netty; +import java.net.Socket; import java.net.URL; +import java.security.InvalidAlgorithmParameterException; import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.Provider; +import java.security.UnrecoverableKeyException; +import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.stream.Collectors; +import javax.net.ssl.KeyManager; import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.KeyManagerFactorySpi; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.SSLEngine; import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedKeyManager; import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContextBuilder; @@ -40,6 +55,7 @@ * * @author Brian Clozel * @author Raheela Aslam + * @author Chris Bono * @since 2.0.0 */ public class SslServerCustomizer implements NettyServerCustomizer { @@ -92,8 +108,10 @@ else if (this.ssl.getClientAuth() == Ssl.ClientAuth.WANT) { protected KeyManagerFactory getKeyManagerFactory(Ssl ssl, SslStoreProvider sslStoreProvider) { try { KeyStore keyStore = getKeyStore(ssl, sslStoreProvider); - KeyManagerFactory keyManagerFactory = KeyManagerFactory - .getInstance(KeyManagerFactory.getDefaultAlgorithm()); + KeyManagerFactory keyManagerFactory = (ssl.getKeyAlias() == null) + ? KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + : ConfigurableAliasKeyManagerFactory.instance(ssl.getKeyAlias(), + KeyManagerFactory.getDefaultAlgorithm()); char[] keyPassword = (ssl.getKeyPassword() != null) ? ssl.getKeyPassword().toCharArray() : null; if (keyPassword == null && ssl.getKeyStorePassword() != null) { keyPassword = ssl.getKeyStorePassword().toCharArray(); @@ -161,4 +179,120 @@ private KeyStore loadStore(String type, String provider, String resource, String } + /** + * A {@link KeyManagerFactory} that allows a configurable key alias to be used. Due to + * the fact that the actual calls to retrieve the key by alias are done at request + * time the approach is to wrap the actual key managers with a + * {@link ConfigurableAliasKeyManager}. The actual SPI has to be wrapped as well due + * to the fact that {@link KeyManagerFactory#getKeyManagers()} is final. + */ + private static final class ConfigurableAliasKeyManagerFactory extends KeyManagerFactory { + + private static ConfigurableAliasKeyManagerFactory instance(String alias, String algorithm) + throws NoSuchAlgorithmException { + KeyManagerFactory originalFactory = KeyManagerFactory.getInstance(algorithm); + ConfigurableAliasKeyManagerFactorySpi spi = new ConfigurableAliasKeyManagerFactorySpi(originalFactory, + alias); + return new ConfigurableAliasKeyManagerFactory(spi, originalFactory.getProvider(), algorithm); + } + + private ConfigurableAliasKeyManagerFactory(ConfigurableAliasKeyManagerFactorySpi spi, Provider provider, + String algorithm) { + super(spi, provider, algorithm); + } + + } + + private static final class ConfigurableAliasKeyManagerFactorySpi extends KeyManagerFactorySpi { + + private KeyManagerFactory originalFactory; + + private String alias; + + private ConfigurableAliasKeyManagerFactorySpi(KeyManagerFactory originalFactory, String alias) { + this.originalFactory = originalFactory; + this.alias = alias; + } + + @Override + protected void engineInit(KeyStore keyStore, char[] chars) + throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { + this.originalFactory.init(keyStore, chars); + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) + throws InvalidAlgorithmParameterException { + throw new InvalidAlgorithmParameterException("Unsupported ManagerFactoryParameters"); + } + + @Override + protected KeyManager[] engineGetKeyManagers() { + return Arrays.stream(this.originalFactory.getKeyManagers()).filter(X509ExtendedKeyManager.class::isInstance) + .map(X509ExtendedKeyManager.class::cast).map(this::wrapKeyManager).collect(Collectors.toList()) + .toArray(new KeyManager[0]); + } + + private ConfigurableAliasKeyManager wrapKeyManager(X509ExtendedKeyManager km) { + return new ConfigurableAliasKeyManager(km, this.alias); + } + + } + + private static final class ConfigurableAliasKeyManager extends X509ExtendedKeyManager { + + private final X509ExtendedKeyManager keyManager; + + private final String alias; + + private ConfigurableAliasKeyManager(X509ExtendedKeyManager keyManager, String alias) { + this.keyManager = keyManager; + this.alias = alias; + } + + @Override + public String chooseEngineClientAlias(String[] strings, Principal[] principals, SSLEngine sslEngine) { + return this.keyManager.chooseEngineClientAlias(strings, principals, sslEngine); + } + + @Override + public String chooseEngineServerAlias(String s, Principal[] principals, SSLEngine sslEngine) { + if (this.alias == null) { + return this.keyManager.chooseEngineServerAlias(s, principals, sslEngine); + } + return this.alias; + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return this.keyManager.chooseClientAlias(keyType, issuers, socket); + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return this.keyManager.chooseServerAlias(keyType, issuers, socket); + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + return this.keyManager.getCertificateChain(alias); + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return this.keyManager.getClientAliases(keyType, issuers); + } + + @Override + public PrivateKey getPrivateKey(String alias) { + return this.keyManager.getPrivateKey(alias); + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return this.keyManager.getServerAliases(keyType, issuers); + } + + } + } diff --git a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/netty/NettyReactiveWebServerFactoryTests.java b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/netty/NettyReactiveWebServerFactoryTests.java index 641a75bfb634..2c6743adf3b5 100644 --- a/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/netty/NettyReactiveWebServerFactoryTests.java +++ b/spring-boot-project/spring-boot/src/test/java/org/springframework/boot/web/embedded/netty/NettyReactiveWebServerFactoryTests.java @@ -16,15 +16,25 @@ package org.springframework.boot.web.embedded.netty; +import java.time.Duration; import java.util.Arrays; +import javax.net.ssl.SSLHandshakeException; + import org.junit.Test; import org.mockito.InOrder; +import reactor.core.publisher.Mono; import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactory; import org.springframework.boot.web.reactive.server.AbstractReactiveWebServerFactoryTests; import org.springframework.boot.web.server.PortInUseException; +import org.springframework.boot.web.server.Ssl; +import org.springframework.http.MediaType; +import org.springframework.http.client.reactive.ReactorClientHttpConnector; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -37,6 +47,7 @@ * Tests for {@link NettyReactiveWebServerFactory}. * * @author Brian Clozel + * @author Chris Bono */ public class NettyReactiveWebServerFactoryTests extends AbstractReactiveWebServerFactoryTests { @@ -83,4 +94,38 @@ public void useForwardedHeaders() { assertForwardHeaderIsUsed(factory); } + @Test + public void whenSslIsConfiguredWithAValidAliasARequestSucceeds() { + Mono result = testSslWithAlias("test-alias"); + StepVerifier.setDefaultTimeout(Duration.ofSeconds(30)); + StepVerifier.create(result).expectNext("Hello World").verifyComplete(); + } + + @Test + public void whenSslIsConfiguredWithAnInvalidAliasTheSslHandshakeFails() { + Mono result = testSslWithAlias("test-alias-bad"); + StepVerifier.setDefaultTimeout(Duration.ofSeconds(30)); + StepVerifier.create(result).expectErrorMatches((throwable) -> throwable instanceof SSLHandshakeException + && throwable.getMessage().contains("HANDSHAKE_FAILURE")).verify(); + } + + protected Mono testSslWithAlias(String alias) { + String keyStore = "classpath:test.jks"; + String keyPassword = "password"; + NettyReactiveWebServerFactory factory = getFactory(); + Ssl ssl = new Ssl(); + ssl.setKeyStore(keyStore); + ssl.setKeyPassword(keyPassword); + ssl.setKeyAlias(alias); + factory.setSsl(ssl); + this.webServer = factory.getWebServer(new EchoHandler()); + this.webServer.start(); + ReactorClientHttpConnector connector = buildTrustAllSslConnector(); + WebClient client = WebClient.builder().baseUrl("https://localhost:" + this.webServer.getPort()) + .clientConnector(connector).build(); + return client.post().uri("/test").contentType(MediaType.TEXT_PLAIN) + .body(BodyInserters.fromObject("Hello World")).exchange() + .flatMap((response) -> response.bodyToMono(String.class)); + } + }