diff --git a/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java b/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java
index ce78c620f53c..8dc458643199 100644
--- a/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java
+++ b/jetty-openid/src/main/java/org/eclipse/jetty/security/openid/OpenIdAuthenticator.java
@@ -64,6 +64,7 @@ public class OpenIdAuthenticator extends LoginAuthenticator
public static final String CLAIMS = "org.eclipse.jetty.security.openid.claims";
public static final String RESPONSE = "org.eclipse.jetty.security.openid.response";
public static final String ERROR_PAGE = "org.eclipse.jetty.security.openid.error_page";
+ public static final String ALWAYS_SAVE_URI = "org.eclipse.jetty.security.openid.always_save_uri";
public static final String J_URI = "org.eclipse.jetty.security.openid.URI";
public static final String J_POST = "org.eclipse.jetty.security.openid.POST";
public static final String J_METHOD = "org.eclipse.jetty.security.openid.METHOD";
@@ -97,6 +98,10 @@ public void setConfiguration(AuthConfiguration configuration)
if (error != null)
setErrorPage(error);
+ String alwaysSaveUri = configuration.getInitParameter(ALWAYS_SAVE_URI);
+ if (alwaysSaveUri != null)
+ setAlwaysSaveUri(Boolean.parseBoolean(alwaysSaveUri));
+
if (_configuration != null)
return;
diff --git a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java
index 210bd8629a3d..85defdb0f82e 100644
--- a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java
+++ b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdAuthenticationTest.java
@@ -35,8 +35,10 @@
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
+@SuppressWarnings("unchecked")
public class OpenIdAuthenticationTest
{
public static final String CLIENT_ID = "testClient101";
@@ -55,6 +57,7 @@ public void setup() throws Exception
server = new Server();
connector = new ServerConnector(server);
+ connector.setPort(8080);
server.addConnector(connector);
ServletContextHandler context = new ServletContextHandler(server, "/", ServletContextHandler.SESSIONS);
@@ -122,30 +125,29 @@ public void stop() throws Exception
@Test
public void testLoginLogout() throws Exception
{
+ openIdProvider.setUser(new OpenIdProvider.User("123456789", "Alice"));
+
String appUriString = "http://localhost:" + connector.getLocalPort();
// Initially not authenticated
ContentResponse response = client.GET(appUriString + "/");
assertThat(response.getStatus(), is(HttpStatus.OK_200));
- String[] content = response.getContentAsString().split("[\r\n]+");
- assertThat(content.length, is(1));
- assertThat(content[0], is("not authenticated"));
+ String content = response.getContentAsString();
+ assertThat(content, containsString("not authenticated"));
// Request to login is success
response = client.GET(appUriString + "/login");
assertThat(response.getStatus(), is(HttpStatus.OK_200));
- content = response.getContentAsString().split("[\r\n]+");
- assertThat(content.length, is(1));
- assertThat(content[0], is("success"));
+ content = response.getContentAsString();
+ assertThat(content, containsString("success"));
// Now authenticated we can get info
response = client.GET(appUriString + "/");
assertThat(response.getStatus(), is(HttpStatus.OK_200));
- content = response.getContentAsString().split("[\r\n]+");
- assertThat(content.length, is(3));
- assertThat(content[0], is("userId: 123456789"));
- assertThat(content[1], is("name: Alice"));
- assertThat(content[2], is("email: Alice@example.com"));
+ content = response.getContentAsString();
+ assertThat(content, containsString("userId: 123456789"));
+ assertThat(content, containsString("name: Alice"));
+ assertThat(content, containsString("email: Alice@example.com"));
// Request to admin page gives 403 as we do not have admin role
response = client.GET(appUriString + "/admin");
@@ -154,9 +156,8 @@ public void testLoginLogout() throws Exception
// We are no longer authenticated after logging out
response = client.GET(appUriString + "/logout");
assertThat(response.getStatus(), is(HttpStatus.OK_200));
- content = response.getContentAsString().split("[\r\n]+");
- assertThat(content.length, is(1));
- assertThat(content[0], is("not authenticated"));
+ content = response.getContentAsString();
+ assertThat(content, containsString("not authenticated"));
}
public static class LoginPage extends HttpServlet
@@ -164,7 +165,9 @@ public static class LoginPage extends HttpServlet
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException
{
+ response.setContentType("text/html");
response.getWriter().println("success");
+ response.getWriter().println("
Home");
}
}
@@ -183,7 +186,7 @@ public static class AdminPage extends HttpServlet
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException
{
- Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS);
+ Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS);
response.getWriter().println(userInfo.get("sub") + ": success");
}
}
@@ -193,18 +196,20 @@ public static class HomePage extends HttpServlet
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException
{
- response.setContentType("text/plain");
+ response.setContentType("text/html");
Principal userPrincipal = request.getUserPrincipal();
if (userPrincipal != null)
{
- Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS);
- response.getWriter().println("userId: " + userInfo.get("sub"));
- response.getWriter().println("name: " + userInfo.get("name"));
- response.getWriter().println("email: " + userInfo.get("email"));
+ Map userInfo = (Map)request.getSession().getAttribute(OpenIdAuthenticator.CLAIMS);
+ response.getWriter().println("userId: " + userInfo.get("sub") + "
");
+ response.getWriter().println("name: " + userInfo.get("name") + "
");
+ response.getWriter().println("email: " + userInfo.get("email") + "
");
+ response.getWriter().println("
Logout");
}
else
{
response.getWriter().println("not authenticated");
+ response.getWriter().println("
Login");
}
}
}
@@ -214,8 +219,9 @@ public static class ErrorPage extends HttpServlet
@Override
protected void doGet(HttpServletRequest request, HttpServletResponse response) throws IOException
{
- response.setContentType("text/plain");
+ response.setContentType("text/html");
response.getWriter().println("not authorized");
+ response.getWriter().println("
Home");
}
}
}
diff --git a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java
index 62eb9ca78e16..d10e873f7249 100644
--- a/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java
+++ b/jetty-openid/src/test/java/org/eclipse/jetty/security/openid/OpenIdProvider.java
@@ -14,6 +14,7 @@
package org.eclipse.jetty.security.openid;
import java.io.IOException;
+import java.io.PrintWriter;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
@@ -21,7 +22,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Random;
+import java.util.Objects;
import java.util.UUID;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
@@ -37,9 +38,13 @@
import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
+import org.eclipse.jetty.util.log.Log;
+import org.eclipse.jetty.util.log.Logger;
public class OpenIdProvider extends ContainerLifeCycle
{
+ private static final Logger LOG = Log.getLogger(OpenIdProvider.class);
+
private static final String CONFIG_PATH = "/.well-known/openid-configuration";
private static final String AUTH_PATH = "/auth";
private static final String TOKEN_PATH = "/token";
@@ -48,10 +53,32 @@ public class OpenIdProvider extends ContainerLifeCycle
protected final String clientId;
protected final String clientSecret;
protected final List redirectUris = new ArrayList<>();
-
+ private final ServerConnector connector;
+ private final Server server;
+ private int port = 0;
private String provider;
- private Server server;
- private ServerConnector connector;
+ private User preAuthedUser;
+
+ public static void main(String[] args) throws Exception
+ {
+ String clientId = "CLIENT_ID123";
+ String clientSecret = "PASSWORD123";
+ int port = 5771;
+ String redirectUri = "http://localhost:8080/openid/auth";
+
+ OpenIdProvider openIdProvider = new OpenIdProvider(clientId, clientSecret);
+ openIdProvider.addRedirectUri(redirectUri);
+ openIdProvider.setPort(port);
+ openIdProvider.start();
+ try
+ {
+ openIdProvider.join();
+ }
+ finally
+ {
+ openIdProvider.stop();
+ }
+ }
public OpenIdProvider(String clientId, String clientSecret)
{
@@ -72,17 +99,43 @@ public OpenIdProvider(String clientId, String clientSecret)
addBean(server);
}
+ public void join() throws InterruptedException
+ {
+ server.join();
+ }
+
+ public OpenIdConfiguration getOpenIdConfiguration()
+ {
+ String provider = getProvider();
+ String authEndpoint = provider + AUTH_PATH;
+ String tokenEndpoint = provider + TOKEN_PATH;
+ return new OpenIdConfiguration(provider, authEndpoint, tokenEndpoint, clientId, clientSecret, null);
+ }
+
@Override
protected void doStart() throws Exception
{
+ connector.setPort(port);
super.doStart();
provider = "http://localhost:" + connector.getLocalPort();
}
- public String getProvider()
+ public void setPort(int port)
{
- if (!isStarted())
+ if (isStarted())
throw new IllegalStateException();
+ this.port = port;
+ }
+
+ public void setUser(User user)
+ {
+ this.preAuthedUser = user;
+ }
+
+ public String getProvider()
+ {
+ if (!isStarted() && port == 0)
+ throw new IllegalStateException("Port of OpenIdProvider not configured");
return provider;
}
@@ -94,7 +147,7 @@ public void addRedirectUri(String uri)
public class OpenIdAuthEndpoint extends HttpServlet
{
@Override
- protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
+ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
if (!clientId.equals(req.getParameter("client_id")))
{
@@ -105,6 +158,7 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws Se
String redirectUri = req.getParameter("redirect_uri");
if (!redirectUris.contains(redirectUri))
{
+ LOG.warn("invalid redirectUri {}", redirectUri);
resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
return;
}
@@ -130,16 +184,71 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws Se
return;
}
+ if (preAuthedUser == null)
+ {
+ PrintWriter writer = resp.getWriter();
+ resp.setContentType("text/html");
+ writer.println("Login to OpenID Connect Provider
");
+ writer.println("");
+ }
+ else
+ {
+ redirectUser(req, preAuthedUser, redirectUri, state);
+ }
+ }
+
+ @Override
+ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException
+ {
+ String redirectUri = req.getParameter("redirectUri");
+ if (!redirectUris.contains(redirectUri))
+ {
+ resp.sendError(HttpServletResponse.SC_FORBIDDEN, "invalid redirect_uri");
+ return;
+ }
+
+ String state = req.getParameter("state");
+ if (state == null)
+ {
+ resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no state param");
+ return;
+ }
+
+ String username = req.getParameter("username");
+ if (username == null)
+ {
+ resp.sendError(HttpServletResponse.SC_FORBIDDEN, "no username");
+ return;
+ }
+
+ User user = new User(username);
+ redirectUser(req, user, redirectUri, state);
+ }
+
+ public void redirectUser(HttpServletRequest request, User user, String redirectUri, String state) throws IOException
+ {
String authCode = UUID.randomUUID().toString().replace("-", "");
- User user = new User(123456789, "Alice");
issuedAuthCodes.put(authCode, user);
- final Request baseRequest = Request.getBaseRequest(req);
- final Response baseResponse = baseRequest.getResponse();
- redirectUri += "?code=" + authCode + "&state=" + state;
- int redirectCode = (baseRequest.getHttpVersion().getVersion() < HttpVersion.HTTP_1_1.getVersion()
- ? HttpServletResponse.SC_MOVED_TEMPORARILY : HttpServletResponse.SC_SEE_OTHER);
- baseResponse.sendRedirect(redirectCode, resp.encodeRedirectURL(redirectUri));
+ try
+ {
+ final Request baseRequest = Objects.requireNonNull(Request.getBaseRequest(request));
+ final Response baseResponse = baseRequest.getResponse();
+ redirectUri += "?code=" + authCode + "&state=" + state;
+ int redirectCode = (baseRequest.getHttpVersion().getVersion() < HttpVersion.HTTP_1_1.getVersion()
+ ? HttpServletResponse.SC_MOVED_TEMPORARILY : HttpServletResponse.SC_SEE_OTHER);
+ baseResponse.sendRedirect(redirectCode, baseResponse.encodeRedirectURL(redirectUri));
+ }
+ catch (Throwable t)
+ {
+ issuedAuthCodes.remove(authCode);
+ throw t;
+ }
}
}
@@ -171,7 +280,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S
long expiry = System.currentTimeMillis() + Duration.ofMinutes(10).toMillis();
String response = "{" +
"\"access_token\": \"" + accessToken + "\"," +
- "\"id_token\": \"" + JwtEncoder.encode(user.getIdToken()) + "\"," +
+ "\"id_token\": \"" + JwtEncoder.encode(user.getIdToken(provider, clientId)) + "\"," +
"\"expires_in\": " + expiry + "," +
"\"token_type\": \"Bearer\"" +
"}";
@@ -184,7 +293,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S
public class OpenIdConfigServlet extends HttpServlet
{
@Override
- protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException
+ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException
{
String discoveryDocument = "{" +
"\"issuer\": \"" + provider + "\"," +
@@ -196,17 +305,17 @@ protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws Se
}
}
- public class User
+ public static class User
{
- private long subject;
- private String name;
+ private final String subject;
+ private final String name;
public User(String name)
{
- this(new Random().nextLong(), name);
+ this(UUID.nameUUIDFromBytes(name.getBytes()).toString(), name);
}
- public User(long subject, String name)
+ public User(String subject, String name)
{
this.subject = subject;
this.name = name;
@@ -217,10 +326,15 @@ public String getName()
return name;
}
- public String getIdToken()
+ public String getSubject()
+ {
+ return subject;
+ }
+
+ public String getIdToken(String provider, String clientId)
{
long expiry = System.currentTimeMillis() + Duration.ofMinutes(1).toMillis();
- return JwtEncoder.createIdToken(provider, clientId, Long.toString(subject), name, expiry);
+ return JwtEncoder.createIdToken(provider, clientId, subject, name, expiry);
}
}
}