Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UNDERTOW-2378] Adjust properly session timeout also for alternative AUTH mechanisms #1585

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@

package io.undertow.servlet.handlers.security;

import static io.undertow.security.api.SecurityNotification.EventType.AUTHENTICATED;
import static io.undertow.util.Methods.POST;
import static io.undertow.util.StatusCodes.OK;

import io.undertow.security.api.AuthenticationMechanism;
import io.undertow.security.api.AuthenticationMechanismFactory;
import io.undertow.security.api.SecurityContext;
import io.undertow.security.api.NotificationReceiver;
import io.undertow.security.api.SecurityNotification;
import io.undertow.security.idm.IdentityManager;
import io.undertow.security.impl.FormAuthenticationMechanism;
import io.undertow.server.HttpServerExchange;
Expand Down Expand Up @@ -153,6 +158,21 @@ public ServletFormAuthenticationMechanism(FormParserFactory formParserFactory, S
this.overrideInitial = overrideInitial;
}

@Override
public AuthenticationMechanismOutcome authenticate(final HttpServerExchange exchange, final SecurityContext securityContext) {
if (POST.equals(exchange.getRequestMethod())) {
securityContext.registerNotificationReceiver(new NotificationReceiver() {
@Override
public void handleNotification(final SecurityNotification notification) {
if (notification.getEventType() == AUTHENTICATED) {
getAndInitializeSession(exchange, false);
}
}
});
}
return super.authenticate(exchange, securityContext);
}

@Override
protected Integer servePage(final HttpServerExchange exchange, final String location) {
final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
Expand Down Expand Up @@ -195,27 +215,7 @@ protected void storeInitialLocation(final HttpServerExchange exchange, byte[] by
if(!saveOriginalRequest) {
return;
}
final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
final ServletContextImpl servletContextImpl = servletRequestContext.getCurrentServletContext();
HttpSessionImpl httpSession = servletContextImpl.getSession(exchange, false);
boolean newSession = false;
if (httpSession == null) {
httpSession = servletContextImpl.getSession(exchange, true);
newSession = true;
}
Session session;
if (System.getSecurityManager() == null) {
session = httpSession.getSession();
} else {
session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession));
}
if (newSession) {
int originalMaxInactiveInterval = session.getMaxInactiveInterval();
if (originalMaxInactiveInterval > authenticationSessionTimeout) {
session.setAttribute(ORIGINAL_SESSION_TIMEOUT, session.getMaxInactiveInterval());
session.setMaxInactiveInterval(authenticationSessionTimeout);
}
}
Session session = getAndInitializeSession(exchange, true);
SessionManager manager = session.getSessionManager();
if (seenSessionManagers.add(manager)) {
manager.registerSessionListener(LISTENER);
Expand All @@ -230,33 +230,57 @@ protected void storeInitialLocation(final HttpServerExchange exchange, byte[] by

@Override
protected void handleRedirectBack(final HttpServerExchange exchange) {
final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
HttpServletResponse resp = (HttpServletResponse) servletRequestContext.getServletResponse();
HttpSessionImpl httpSession = servletRequestContext.getCurrentServletContext().getSession(exchange, false);
if (httpSession != null) {
Session session;
if (System.getSecurityManager() == null) {
session = httpSession.getSession();
} else {
session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession));
}
Integer originalSessionTimeout = (Integer) session.removeAttribute(ORIGINAL_SESSION_TIMEOUT);
if (originalSessionTimeout != null) {
session.setMaxInactiveInterval(originalSessionTimeout);
}
final Session session = getAndInitializeSession(exchange, false);
if (session != null) {
String path = (String) session.getAttribute(SESSION_KEY);
if ((path == null || overrideInitial) && defaultPage != null) {
path = defaultPage;
}
if (path != null) {
final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
final HttpServletResponse resp = (HttpServletResponse) servletRequestContext.getServletResponse();
try {
resp.sendRedirect(path);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}

private Session getAndInitializeSession(final HttpServerExchange exchange, final boolean createNewSession) {
final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
final ServletContextImpl servletContextImpl = servletRequestContext.getCurrentServletContext();
HttpSessionImpl httpSession = servletContextImpl.getSession(exchange, false);
if (httpSession == null && !createNewSession) return null;

boolean newSession = false;
if (httpSession == null) {
httpSession = servletContextImpl.getSession(exchange, true);
newSession = true;
}

Session session;
if (System.getSecurityManager() == null) {
session = httpSession.getSession();
} else {
session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession));
}

if (newSession) {
final int originalMaxInactiveInterval = session.getMaxInactiveInterval();
if (originalMaxInactiveInterval > authenticationSessionTimeout) {
session.setAttribute(ORIGINAL_SESSION_TIMEOUT, session.getMaxInactiveInterval());
session.setMaxInactiveInterval(authenticationSessionTimeout);
}
} else {
final Integer originalSessionTimeout = (Integer) session.removeAttribute(ORIGINAL_SESSION_TIMEOUT);
if (originalSessionTimeout != null) {
session.setMaxInactiveInterval(originalSessionTimeout);
}
}

return session;
}

private static class FormResponseWrapper extends HttpServletResponseWrapper {
Expand Down