-
Notifications
You must be signed in to change notification settings - Fork 5
/
Http4sServlet.scala
198 lines (181 loc) · 6.81 KB
/
Http4sServlet.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
/*
* Copyright 2013 http4s.org
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.http4s.servlet
import cats.effect.kernel.Async
import cats.effect.kernel.Sync
import cats.effect.std.Dispatcher
import cats.syntax.all._
import com.comcast.ip4s.IpAddress
import com.comcast.ip4s.Port
import com.comcast.ip4s.SocketAddress
import org.http4s._
import org.http4s.internal.CollectionCompat.CollectionConverters._
import org.http4s.server.SecureSession
import org.http4s.server.ServerRequestKeys
import org.log4s.Logger
import org.log4s.getLogger
import org.typelevel.ci._
import org.typelevel.vault._
import java.security.cert.X509Certificate
import javax.servlet.ServletConfig
import javax.servlet.http.HttpServlet
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import javax.servlet.http.HttpSession
/** Base class for servlets that serve an http4s [[HttpApp]].
  *
  * Holds the plumbing shared by concrete servlet implementations: converting a
  * `HttpServletRequest` into an http4s [[Request]] (`toRequest`) and rendering an
  * http4s [[Response]] onto a `HttpServletResponse` (`renderResponse`). The
  * application itself is exposed as `serviceFn`; subclasses are responsible for
  * wiring it into the servlet request lifecycle.
  *
  * @param service the http4s application to run for each request
  * @param servletIo supplies the request-body reader used by `toRequest`
  * @param dispatcher used to run effects synchronously (e.g. vault key allocation)
  */
abstract class Http4sServlet[F[_]](
    service: HttpApp[F],
    servletIo: ServletIo[F],
    dispatcher: Dispatcher[F],
)(implicit F: Sync[F])
    extends HttpServlet {

  @deprecated("Binary compatibility", "0.23.12")
  private[servlet] def this(
      service: HttpApp[F],
      servletIo: ServletIo[F],
      dispatcher: Dispatcher[F],
      async: Async[F],
  ) = this(service, servletIo, dispatcher)(async: Sync[F])

  protected val logger: Logger = getLogger

  // micro-optimization: unwrap the service and call its .run directly
  protected val serviceFn: Request[F] => F[Response[F]] = service.run

  // Both are assigned in `init`; they are null until the container calls it.
  protected var servletApiVersion: ServletApiVersion = _
  private[this] var serverSoftware: ServerSoftware = _

  /** Vault keys for servlet-specific request attributes. */
  object ServletRequestKeys {

    /** The request's existing `HttpSession`, if any (looked up with `getSession(false)`).
      * The key is allocated once, synchronously, through the dispatcher.
      */
    val HttpSession: Key[Option[HttpSession]] = {
      val result = Key.newKey[F, Option[HttpSession]]
      dispatcher.unsafeRunSync(result)
    }
  }

  /** Initializes the servlet and records, from the servlet context, the Servlet API
    * version (which is also logged) and the server software description.
    */
  override def init(config: ServletConfig): Unit = {
    super.init(config)
    val servletContext = config.getServletContext
    servletApiVersion = ServletApiVersion(servletContext)
    logger.info(s"Detected Servlet API version $servletApiVersion")
    serverSoftware = ServerSoftware(servletContext.getServerInfo)
  }

  @deprecated("Use the overload without bodyWriter.", "0.23.15")
  protected def onParseFailure(
      parseFailure: ParseFailure,
      servletResponse: HttpServletResponse,
      bodyWriter: BodyWriter[F],
  ): F[Unit] =
    onParseFailure(parseFailure, servletResponse)

  /** Responds to a request that could not be converted into an http4s [[Request]]
    * by sending a 400 Bad Request carrying the failure's sanitized message.
    */
  protected def onParseFailure(
      parseFailure: ParseFailure,
      servletResponse: HttpServletResponse,
  ): F[Unit] =
    F.delay(servletResponse.sendError(Status.BadRequest.code, parseFailure.sanitized))

  /** Writes `response` onto `servletResponse`.
    *
    * The status code and all headers except `Transfer-Encoding` are copied first;
    * if that succeeds, the body is handed to `bodyWriter`. If copying fails, the
    * response body is still drained (presumably so its finalizers run), any drain
    * error is logged, and the original error is re-raised.
    */
  protected def renderResponse(
      response: Response[F],
      servletResponse: HttpServletResponse,
      bodyWriter: BodyWriter[F],
  ): F[Unit] =
    // Note: the servlet API gives us no undeprecated method to both set
    // a body and a status reason. We sacrifice the status reason.
    //
    // This F.attempt.flatMap can be interrupted, which prevents the body from
    // running, which prevents the response from finalizing. Woe betide you if
    // your effect isn't Concurrent.
    F.delay {
      servletResponse.setStatus(response.status.code)
      for (header <- response.headers.headers if header.name != ci"Transfer-Encoding")
        servletResponse.addHeader(header.name.toString, header.value)
    }.attempt
      .flatMap {
        case Right(()) => bodyWriter(response)
        case Left(t) =>
          response.body.compile.drain.handleError { t2 =>
            logger.error(t2)("Error draining body")
          } *> F.raiseError(t)
      }

  /** Converts a servlet request into an http4s [[Request]].
    *
    * Fails with a [[ParseFailure]] when the method, request target, protocol
    * version, path-info split point, or local/remote socket address cannot be
    * parsed. The request body is read lazily through `servletIo.reader`.
    */
  protected def toRequest(req: HttpServletRequest): ParseResult[Request[F]] =
    for {
      method <- Method.fromString(req.getMethod)
      uri <- Uri.requestTarget(
        Option(req.getQueryString)
          .map { q =>
            s"${req.getRequestURI}?$q"
          }
          .getOrElse(req.getRequestURI)
      )
      version <- HttpVersion.fromString(req.getProtocol)
      pathInfoIndex <- getPathInfoIndex(req, uri)
      connection <- toConnectionInfo(req)
      // Still guarded: first access to `ServletRequestKeys` runs the dispatcher,
      // and servlet container calls below are outside our control.
      attributes <- ParseResult.fromTryCatchNonFatal("")(
        Vault.empty
          .insert(Request.Keys.PathInfoCaret, pathInfoIndex)
          .insert(Request.Keys.ConnectionInfo, connection)
          .insert(Request.Keys.ServerSoftware, serverSoftware)
          .insert(ServletRequestKeys.HttpSession, Option(req.getSession(false)))
          .insert(ServerRequestKeys.SecureSession, toSecureSession(req))
      )
    } yield Request(
      method = method,
      uri = uri,
      httpVersion = version,
      headers = toHeaders(req),
      body = servletIo.reader(req),
      attributes = attributes,
    )

  /** Builds the local/remote connection info, failing if either address or port is invalid. */
  private def toConnectionInfo(req: HttpServletRequest): ParseResult[Request.Connection] =
    for {
      local <- toSocketAddress("local", req.getLocalAddr, req.getLocalPort)
      remote <- toSocketAddress("remote", req.getRemoteAddr, req.getRemotePort)
    } yield Request.Connection(local = local, remote = remote, secure = req.isSecure)

  /** Parses a servlet-reported address/port pair; a null or unparsable value is a failure. */
  private def toSocketAddress(
      side: String,
      addr: String,
      port: Int,
  ): ParseResult[SocketAddress[IpAddress]] =
    (for {
      host <- Option(addr).flatMap(a => IpAddress.fromString(stripBracketsFromAddr(a)))
      p <- Port.fromInt(port)
    } yield SocketAddress(host, p))
      .toRight(ParseFailure("", s"Invalid $side address: addr='$addr', port=$port"))

  /** Reads the TLS request attributes; yields a session only when all four are present
    * with the expected runtime types.
    */
  private def toSecureSession(req: HttpServletRequest): Option[SecureSession] = {
    def attribute(name: String): Option[AnyRef] = Option(req.getAttribute(name))
    (
      attribute("javax.servlet.request.ssl_session_id").collect { case id: String => id },
      attribute("javax.servlet.request.cipher_suite").collect { case suite: String => suite },
      // Type-test instead of `asInstanceOf[Int]`: unboxing null yields 0, which
      // would turn a missing key size into Some(0).
      attribute("javax.servlet.request.key_size").collect { case size: Integer =>
        size.intValue
      },
      attribute("javax.servlet.request.X509Certificate").collect {
        case certs: Array[X509Certificate] => certs
      },
    )
      .mapN(SecureSession.apply)
  }

  /** Locates where path info begins in the request path, i.e. the split point of
    * `contextPath + servletPath` (via `Uri.Path.findSplit`), failing if it is not found.
    */
  private def getPathInfoIndex(req: HttpServletRequest, uri: Uri): ParseResult[Int] = {
    val prefix =
      Uri.Path
        .unsafeFromString(req.getContextPath)
        .concat(Uri.Path.unsafeFromString(req.getServletPath))
    uri.path
      .findSplit(prefix)
      .toRight(
        ParseFailure(
          uri.path.renderString,
          s"Couldn't find pathInfoIndex given the contextPath='${req.getContextPath}' and servletPath='${req.getServletPath}'.",
        )
      )
  }

  /** Collects every (name, value) header pair from the servlet request.
    *
    * The Servlet API allows `getHeaderNames`/`getHeaders` to return null when the
    * container does not permit header access; that is treated as "no headers".
    */
  protected def toHeaders(req: HttpServletRequest): Headers = {
    val headers = for {
      name <- Option(req.getHeaderNames).fold(Iterator.empty[String])(_.asScala)
      value <- Option(req.getHeaders(name)).fold(Iterator.empty[String])(_.asScala)
    } yield name -> value
    Headers(headers.toList)
  }

  /** Removes the surrounding brackets some containers put around IPv6 literals. */
  private final def stripBracketsFromAddr(addr: String): String =
    addr.stripPrefix("[").stripSuffix("]")
}