diff --git a/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java b/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java index ce78c620f53c..8dc458643199 100644 --- a/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java +++ b/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java @@ -64,6 +64,7 @@ public class OpenIdAuthenticator extends LoginAuthenticator public static final String CLAIMS = "org.eclipse.jetty.security.openid.claims"; public static final String RESPONSE = "org.eclipse.jetty.security.openid.response"; public static final String ERROR_PAGE = "org.eclipse.jetty.security.openid.error_page"; + public static final String ALWAYS_SAVE_URI = "org.eclipse.jetty.security.openid.always_save_uri"; public static final String J_URI = "org.eclipse.jetty.security.openid.URI"; public static final String J_POST = "org.eclipse.jetty.security.openid.POST"; public static final String J_METHOD = "org.eclipse.jetty.security.openid.METHOD"; @@ -97,6 +98,10 @@ public void setConfiguration(AuthConfiguration configuration) if (error != null) setErrorPage(error); + String alwaysSaveUri = configuration.getInitParameter(ALWAYS_SAVE_URI); + if (alwaysSaveUri != null) + setAlwaysSaveUri(Boolean.parseBoolean(alwaysSaveUri)); + if (_configuration != null) return; diff --git a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java index 210bd8629a3d..85defdb0f82e 100644 --- a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java +++ b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java @@ -35,8 +35,10 @@ import org.junit.jupiter.api.Test; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +@SuppressWarnings("unchecked") public class OpenIdAuthenticationTest { public static final String CLIENT_ID = "testClient101"; @@ -55,6 +57,7 @@ public void setup() throws Exception server = new Server(); connector = new ServerConnector(server); + connector.setPort(8080); server.addConnector(connector); ServletContextHandler context = new ServletContextHandler(server, "/", ServletContextHandler.SESSIONS); @@ -122,30 +125,29 @@ public void stop() throws Exception @Test public void testLoginLogout() throws Exception { + openIdProvider.setUser(new OpenIdProvider.User("123456789", "Alice")); + String appUriString = "http://localhost:" + connector.getLocalPort(); // Initially not authenticated ContentResponse response = client.GET(appUriString + "/"); assertThat(response.getStatus(), is(HttpStatus.OK_200)); - String[] content = response.getContentAsString().split("[\r\n]+"); - assertThat(content.length, is(1)); - assertThat(content[0], is("not authenticated")); + String content = response.getContentAsString(); + assertThat(content, containsString("not authenticated")); // Request to login is success response = client.GET(appUriString + "/login"); assertThat(response.getStatus(), is(HttpStatus.OK_200)); - content = response.getContentAsString().split("[\r\n]+"); - assertThat(content.length, is(1)); - assertThat(content[0], is("success")); + content = response.getContentAsString(); + assertThat(content, containsString("success")); // Now authenticated we can get info response = client.GET(appUriString + "/"); assertThat(response.getStatus(), is(HttpStatus.OK_200)); - content = response.getContentAsString().split("[\r\n]+"); - assertThat(content.length, is(3)); - assertThat(content[0], is("userId: 123456789")); - assertThat(content[1], is("name: Alice")); - assertThat(content[2], is("email: Alice@example.com")); + content = response.getContentAsString(); + assertThat(content, containsString("userId: 123456789")); + assertThat(content, containsString("name: Alice")); + assertThat(content, containsString("email: Alice@example.com")); // Request to admin page gives 403 as we do not have admin role response = client.GET(appUriString + "/admin"); @@ -154,9 +156,8 @@ public void testLoginLogout() throws Exception // We are no longer authenticated after logging out response = client.GET(appUriString + "/logout"); assertThat(response.getStatus(), is(HttpStatus.OK_200)); - content = response.getContentAsString().split("[\r\n]+"); - assertThat(content.length, is(1)); - assertThat(content[0], is("not authenticated")); + content = response.getContentAsString(); + assertThat(content, containsString("not authenticated")); } public static class LoginPage extends HttpServlet @@ -164,7 +165,9 @@ public static class LoginPage extends HttpServlet @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { + response.setContentType("text/html"); response.getWriter().println("success"); + response.getWriter().println("
Home"); } } @@ -183,7 +186,7 @@ public static class AdminPage extends HttpServlet @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { - Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS); + Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS); response.getWriter().println(userInfo.get("sub") + ": success"); } } @@ -193,18 +196,20 @@ public static class HomePage extends HttpServlet @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { - response.setContentType("text/plain"); + response.setContentType("text/html"); Principal userPrincipal = request.getUserPrincipal(); if (userPrincipal != null) { - Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS); - response.getWriter().println("userId: " + userInfo.get("sub")); - response.getWriter().println("name: " + userInfo.get("name")); - response.getWriter().println("email: " + userInfo.get("email")); + Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS); + response.getWriter().println("userId: " + userInfo.get("sub") + "
"); + response.getWriter().println("name: " + userInfo.get("name") + "
"); + response.getWriter().println("email: " + userInfo.get("email") + "
"); + response.getWriter().println("
Logout"); } else { response.getWriter().println("not authenticated"); + response.getWriter().println("
Login"); } } } @@ -214,8 +219,9 @@ public static class ErrorPage extends HttpServlet @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException { - response.setContentType("text/plain"); + response.setContentType("text/html"); response.getWriter().println("not authorized"); + response.getWriter().println("
Home"); } } } diff --git a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java index 62eb9ca78e16..d10e873f7249 100644 --- a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java +++ b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java @@ -14,6 +14,7 @@ package org.eclipse.jetty.security.openid; import java.io.IOException; +import java.io.PrintWriter; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; @@ -21,7 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Random; +import java.util.Objects; import java.util.UUID; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; @@ -37,9 +38,13 @@ import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.util.StringUtil; import org.eclipse.jetty.util.component.ContainerLifeCycle; +import org.eclipse.jetty.util.log.Log; +import org.eclipse.jetty.util.log.Logger; public class OpenIdProvider extends ContainerLifeCycle { + private static final Logger LOG = Log.getLogger(OpenIdProvider.class); + private static final String CONFIG_PATH = "/.well-known/openid-configuration"; private static final String AUTH_PATH = "/auth"; private static final String TOKEN_PATH = "/token"; @@ -48,10 +53,32 @@ public class OpenIdProvider extends ContainerLifeCycle protected final String clientId; protected final String clientSecret; protected final List redirectUris = new ArrayList<>(); - + private final ServerConnector connector; + private final Server server; + private int port = 0; private String provider; - private Server server; - private ServerConnector connector; + private User preAuthedUser; + + public static void main(String[] args) throws Exception + { + String clientId = "CLIENT_ID123"; + String clientSecret = "PASSWORD123"; + int port = 5771; + String redirectUri = "http://localhost:8080/openid/auth"; + + OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret); + openIdProvider.addRedirectUri(redirectUri); + openIdProvider.setPort(port); + openIdProvider.start(); + try + { + openIdProvider.join(); + } + finally + { + openIdProvider.stop(); + } + } public OpenIdProvider(String clientId, String clientSecret) { @@ -72,17 +99,43 @@ public OpenIdProvider(String clientId, String clientSecret) addBean(server); } + public void join() throws InterruptedException + { + server.join(); + } + + public OpenIdConfiguration getOpenIdConfiguration() + { + String provider = getProvider(); + String authEndpoint = provider + AUTH_PATH; + String tokenEndpoint = provider + TOKEN_PATH; + return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null); + } + @Override protected void doStart() throws Exception { + connector.setPort(port); super.doStart(); provider = "http://localhost:" + connector.getLocalPort(); } - public String getProvider() + public void setPort(int port) { - if (!isStarted()) + if (isStarted()) throw new IllegalStateException(); + this.port = port; + } + + public void setUser(User user) + { + this.preAuthedUser = user; + } + + public String getProvider() + { + if (!isStarted() && port == 0) + throw new IllegalStateException("Port of OpenIdProvider not configured"); return provider; } @@ -94,7 +147,7 @@ public void addRedirectUri(String uri) public class OpenIdAuthEndpoint extends HttpServlet { @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException { if (!clientId.equals(req.getParameter("client_id"))) { @@ -105,6 +158,7 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws Se String redirectUri = req.getParameter("redirect_uri"); if (!redirectUris.contains(redirectUri)) { + LOG.warn("invalid redirectUri {}", redirectUri); resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri"); return; } @@ -130,16 +184,71 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws Se return; } + if (preAuthedUser == null) + { + PrintWriter writer = resp.getWriter(); + resp.setContentType("text/html"); + writer.println("

Login to OpenID Connect Provider

"); + writer.println("
"); + writer.println(""); + writer.println(""); + writer.println(""); + writer.println(""); + writer.println("
"); + } + else + { + redirectUser(req, preAuthedUser, redirectUri, state); + } + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException + { + String redirectUri = req.getParameter("redirectUri"); + if (!redirectUris.contains(redirectUri)) + { + resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri"); + return; + } + + String state = req.getParameter("state"); + if (state == null) + { + resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param"); + return; + } + + String username = req.getParameter("username"); + if (username == null) + { + resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no username"); + return; + } + + User user = new User(username); + redirectUser(req, user, redirectUri, state); + } + + public void redirectUser(HttpServletRequest request, User user, String redirectUri, String state) throws IOException + { String authCode = UUID.randomUUID().toString().replace("-", ""); - User user = new User(123456789, "Alice"); issuedAuthCodes.put(authCode, user); - final Request baseRequest = Request.getBaseRequest(req); - final Response baseResponse = baseRequest.getResponse(); - redirectUri += "?code=" + authCode + "&state=" + state; - int redirectCode = (baseRequest.getHttpVersion().getVersion() < HttpVersion.HTTP_1_1.getVersion() - ? HttpServletResponse.SC_MOVED_TEMPORARILY : HttpServletResponse.SC_SEE_OTHER); - baseResponse.sendRedirect(redirectCode, resp.encodeRedirectURL(redirectUri)); + try + { + final Request baseRequest = Objects.requireNonNull(Request.getBaseRequest(request)); + final Response baseResponse = baseRequest.getResponse(); + redirectUri += "?code=" + authCode + "&state=" + state; + int redirectCode = (baseRequest.getHttpVersion().getVersion() < HttpVersion.HTTP_1_1.getVersion() + ? HttpServletResponse.SC_MOVED_TEMPORARILY : HttpServletResponse.SC_SEE_OTHER); + baseResponse.sendRedirect(redirectCode, baseResponse.encodeRedirectURL(redirectUri)); + } + catch (Throwable t) + { + issuedAuthCodes.remove(authCode); + throw t; + } } } @@ -171,7 +280,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S long expiry = System.currentTimeMillis() + Duration.ofMinutes(10).toMillis(); String response = "{" + "\"access_token\": \"" + accessToken + "\"," + - "\"id_token\": \"" + JwtEncoder.encode(user.getIdToken()) + "\"," + + "\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId)) + "\"," + "\"expires_in\": " + expiry + "," + "\"token_type\": \"Bearer\"" + "}"; @@ -184,7 +293,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S public class OpenIdConfigServlet extends HttpServlet { @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException { String discoveryDocument = "{" + "\"issuer\": \"" + provider + "\"," + @@ -196,17 +305,17 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws Se } } - public class User + public static class User { - private long subject; - private String name; + private final String subject; + private final String name; public User(String name) { - this(new Random().nextLong(), name); + this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name); } - public User(long subject, String name) + public User(String subject, String name) { this.subject = subject; this.name = name; @@ -217,10 +326,15 @@ public String getName() return name; } - public String getIdToken() + public String getSubject() + { + return subject; + } + + public String getIdToken(String provider, String clientId) { long expiry = System.currentTimeMillis() + Duration.ofMinutes(1).toMillis(); - return JwtEncoder.createIdToken(provider, clientId, Long.toString(subject), name, expiry); + return JwtEncoder.createIdToken(provider, clientId, subject, name, expiry); } } }