From 65351673e60c3452261279c2a341079b198335d6 Mon Sep 17 00:00:00 2001 From: "Ross A. Baker" Date: Fri, 30 Jun 2023 18:22:40 -0400 Subject: [PATCH] Render async timeouts in the listener, as Jetty 9 requires --- .../http4s/servlet/AsyncHttp4sServlet.scala | 40 ++++++++++++------- .../servlet/BlockingHttp4sServlet.scala | 2 +- .../org/http4s/servlet/Http4sServlet.scala | 13 ++++-- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala b/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala index f93996cb..d97b2501 100644 --- a/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala +++ b/servlet/src/main/scala/org/http4s/servlet/AsyncHttp4sServlet.scala @@ -62,17 +62,13 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0. ctx.setTimeout(asyncTimeoutMillis) // Must be done on the container thread for Tomcat's sake when using async I/O. val bodyWriter = servletIo.initWriter(servletResponse) - val result = F - .attempt( - toRequest(servletRequest).fold( - onParseFailure(_, servletResponse, bodyWriter), + val result = + toRequest(servletRequest) + .fold( + onParseFailure(_, servletResponse), handleRequest(ctx, _, bodyWriter), ) - ) - .flatMap { - case Right(()) => F.delay(ctx.complete) - case Left(t) => errorHandler(servletRequest, servletResponse)(t) - } + .recoverWith(errorHandler(servletRequest, servletResponse)) dispatcher.unsafeRunAndForget(result) } catch errorHandler(servletRequest, servletResponse).andThen(dispatcher.unsafeRunSync _) @@ -87,17 +83,23 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0. // It is an error to add a listener to an async context that is // already completed, so we must take care to add the listener // before the response can complete. - val timeout = - F.async[Response[F]](cb => + F.async[Unit](cb => gate.complete(ctx.addListener(new AsyncTimeoutHandler(cb))).as(noopCancelToken) ) val response = gate.get *> F.defer(serviceFn(request)) .recoverWith(serviceErrorHandler(request)) - val servletResponse = ctx.getResponse.asInstanceOf[HttpServletResponse] - F.race(timeout, response).flatMap(r => renderResponse(r.merge, servletResponse, bodyWriter)) + F.race(timeout, response).flatMap { + case Left(_) => + // In Jetty, if onTimeout is called, we need to complete on the + // listener's own thread. + F.unit + case Right(resp) => + val servletResponse = ctx.getResponse.asInstanceOf[HttpServletResponse] + renderResponse(resp, servletResponse, bodyWriter) *> F.delay(ctx.complete()) + } } private def errorHandler( @@ -124,11 +126,19 @@ class AsyncHttp4sServlet[F[_]] @deprecated("Use AsyncHttp4sServlet.builder", "0. } } - private class AsyncTimeoutHandler(cb: Callback[Response[F]]) extends AbstractAsyncListener { + private class AsyncTimeoutHandler(cb: Callback[Unit]) extends AbstractAsyncListener { override def onTimeout(event: AsyncEvent): Unit = { + // In Jetty, we must complete on the same thread as the timeout + // handler. This triggers a cancellation of the service so we + // can take over. + cb(Right(())) + + val ctx = event.getAsyncContext val req = event.getAsyncContext.getRequest.asInstanceOf[HttpServletRequest] logger.info(s"Request timed out: ${req.getMethod} ${req.getServletPath}${req.getPathInfo}") - cb(Right(Response.timeout[F])) + val resp = event.getAsyncContext.getResponse.asInstanceOf[HttpServletResponse] + resp.sendError(Response.timeout.status.code, "Response timed out") + ctx.complete() } } } diff --git a/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala b/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala index 3685fcae..9e0760ec 100644 --- a/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala +++ b/servlet/src/main/scala/org/http4s/servlet/BlockingHttp4sServlet.scala @@ -61,7 +61,7 @@ class BlockingHttp4sServlet[F[_]] private ( val bodyWriter = servletIo.initWriter(servletResponse) val render = toRequest(servletRequest).fold( - onParseFailure(_, servletResponse, bodyWriter), + onParseFailure(_, servletResponse), handleRequest(_, servletResponse, bodyWriter), ) diff --git a/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala b/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala index c6893266..0115ab18 100644 --- a/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala +++ b/servlet/src/main/scala/org/http4s/servlet/Http4sServlet.scala @@ -77,14 +77,19 @@ abstract class Http4sServlet[F[_]]( serverSoftware = ServerSoftware(servletContext.getServerInfo) } + @deprecated("Use the overload without bodyWriter.", "0.23.15") protected def onParseFailure( parseFailure: ParseFailure, servletResponse: HttpServletResponse, bodyWriter: BodyWriter[F], - ): F[Unit] = { - val response = Response[F](Status.BadRequest).withEntity(parseFailure.sanitized) - renderResponse(response, servletResponse, bodyWriter) - } + ): F[Unit] = + onParseFailure(parseFailure, servletResponse) + + protected def onParseFailure( + parseFailure: ParseFailure, + servletResponse: HttpServletResponse, + ): F[Unit] = + F.delay(servletResponse.sendError(Status.BadRequest.code, parseFailure.sanitized)) protected def renderResponse( response: Response[F],