Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement WebSocket support based on Plug.Conn.Adapter.upgrade/3 #88

Merged
merged 9 commits into from Oct 31, 2022
28 changes: 28 additions & 0 deletions lib/plug/cowboy.ex
Expand Up @@ -125,6 +125,34 @@ defmodule Plug.Cowboy do

To opt-out of this default instrumentation, you can manually configure
cowboy with the option `stream_handlers: [:cowboy_stream_h]`.

## WebSocket support

Plug.Cowboy supports upgrading HTTP requests to WebSocket connections via
the use of the `Plug.Conn.upgrade_adapter/3` function, called with `:websocket` as the second
argument. Applications should validate that the connection represents a valid WebSocket request
before calling this function (Cowboy will validate the connection as part of the upgrade
process, but does not provide any capacity for an application to be notified if the upgrade is
not successful). If an application wishes to negotiate WebSocket subprotocols or otherwise set
any response headers, it should do so before calling `Plug.Conn.upgrade_adapter/3`.

The third argument to `Plug.Conn.upgrade_adapter/3` defines the details of how Plug.Cowboy
should handle the WebSocket connection, and must take the form `{handler, handler_opts,
connection_opts}`, where values are as follows:
josevalim marked this conversation as resolved.
Show resolved Hide resolved

* `handler` is a module which implements the
[`:cowboy_websocket`](https://ninenines.eu/docs/en/cowboy/2.6/manual/cowboy_websocket/)
behaviour. Note that this module will NOT have its `c:cowboy_websocket.init/2` callback
called; only the 'later' parts of the `:cowboy_websocket` lifecycle are supported
* `handler_opts` is an arbitrary term which will be passed as the argument to
`c:cowboy_websocket.websocket_init/1`
* `connection_opts` is a keyword list which consists of zero or more of the following options:
* `idle_timeout`: The number of milliseconds to wait after no client data is received before
closing the connection. Defaults to `60_000`
* `compress`: Whether or not to attempt negotiation of a compression extension with the
client. Defaults to `false`
* `max_frame_size`: The maximum frame size in bytes to accept from the client. The connection
will be closed if the client sends a frame larger than this suze. Defaults to `:infinity`
mtrudel marked this conversation as resolved.
Show resolved Hide resolved
"""

require Logger
Expand Down
7 changes: 7 additions & 0 deletions lib/plug/cowboy/conn.ex
Expand Up @@ -90,6 +90,13 @@ defmodule Plug.Cowboy.Conn do
:cowboy_req.inform(status, to_headers_map(headers), req)
end

@impl true
def upgrade(req, :websocket, {_handler, _state, _cowboy_opts} = args) do
mtrudel marked this conversation as resolved.
Show resolved Hide resolved
{:ok, Map.put(req, :upgrade, {:websocket, args})}
end

def upgrade(_req, _protocol, _args), do: {:error, :not_supported}

@impl true
def push(req, path, headers) do
opts =
Expand Down
25 changes: 20 additions & 5 deletions lib/plug/cowboy/handler.ex
Expand Up @@ -7,12 +7,17 @@ defmodule Plug.Cowboy.Handler do
conn = @connection.conn(req)

try do
%{adapter: {@connection, req}} =
conn
|> plug.call(opts)
|> maybe_send(plug)
conn
|> plug.call(opts)
|> maybe_send(plug)
|> case do
%Plug.Conn{adapter: {@connection, %{upgrade: {:websocket, websocket_args}} = req}} = conn ->
{handler, state, cowboy_opts} = websocket_args
{__MODULE__, copy_resp_headers(conn, req), {handler, state}, cowboy_opts}

{:ok, req, {plug, opts}}
%Plug.Conn{adapter: {@connection, req}} ->
{:ok, req, {plug, opts}}
end
catch
kind, reason ->
exit_on_error(kind, reason, __STACKTRACE__, {plug, :call, [conn, opts]})
Expand All @@ -25,6 +30,16 @@ defmodule Plug.Cowboy.Handler do
end
end

def upgrade(req, env, __MODULE__, {handler, state}, opts) do
:cowboy_websocket.upgrade(req, env, handler, state, opts)
end

defp copy_resp_headers(%Plug.Conn{} = conn, req) do
Enum.reduce(conn.resp_headers, req, fn {key, val}, acc ->
:cowboy_req.set_resp_header(key, val, acc)
end)
end

defp exit_on_error(
:error,
%Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack},
Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Expand Up @@ -33,7 +33,7 @@ defmodule Plug.Cowboy.MixProject do

def deps do
[
{:plug, "~> 1.7"},
{:plug, "~> 1.14"},
{:cowboy, "~> 2.7"},
{:cowboy_telemetry, "~> 0.3"},
{:ex_doc, "~> 0.20", only: :docs},
Expand Down
8 changes: 4 additions & 4 deletions mix.lock
Expand Up @@ -11,12 +11,12 @@
"makeup": {:hex, :makeup, "1.0.5", "d5a830bc42c9800ce07dd97fa94669dfb93d3bf5fcf6ea7a0c67b2e0e4a7f26c", [:mix], [{:nimble_parsec, "~> 0.5 or ~> 1.0", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cfa158c02d3f5c0c665d0af11512fed3fba0144cf1aadee0f2ce17747fba2ca9"},
"makeup_elixir": {:hex, :makeup_elixir, "0.15.1", "b5888c880d17d1cc3e598f05cdb5b5a91b7b17ac4eaf5f297cb697663a1094dd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.1", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "db68c173234b07ab2a07f645a5acdc117b9f99d69ebf521821d89690ae6c6ec8"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
"mime": {:hex, :mime, "1.5.0", "203ef35ef3389aae6d361918bf3f952fa17a09e8e43b5aa592b93eba05d0fb8d", [:mix], [], "hexpm", "55a94c0f552249fc1a3dd9cd2d3ab9de9d3c89b559c2bd01121f824834f24746"},
"mime": {:hex, :mime, "2.0.3", "3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"},
"nimble_parsec": {:hex, :nimble_parsec, "1.1.0", "3a6fca1550363552e54c216debb6a9e95bd8d32348938e13de5eda962c0d7f89", [:mix], [], "hexpm", "08eb32d66b706e913ff748f11694b17981c0b04a33ef470e33e11b3d3ac8f54b"},
"plug": {:hex, :plug, "1.11.0", "f17217525597628298998bc3baed9f8ea1fa3f1160aa9871aee6df47a6e4d38e", [:mix], [{:mime, "~> 1.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "2d9c633f0499f9dc5c2fd069161af4e2e7756890b81adcbb2ceaa074e8308876"},
"plug_crypto": {:hex, :plug_crypto, "1.2.0", "1cb20793aa63a6c619dd18bb33d7a3aa94818e5fd39ad357051a67f26dfa2df6", [:mix], [], "hexpm", "a48b538ae8bf381ffac344520755f3007cc10bd8e90b240af98ea29b69683fc2"},
"plug": {:hex, :plug, "1.14.0", "ba4f558468f69cbd9f6b356d25443d0b796fbdc887e03fa89001384a9cac638f", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bf020432c7d4feb7b3af16a0c2701455cbbbb95e5b6866132cb09eb0c29adc14"},
"plug_crypto": {:hex, :plug_crypto, "1.2.3", "8f77d13aeb32bfd9e654cb68f0af517b371fb34c56c9f2b58fe3df1235c1251a", [:mix], [], "hexpm", "b5672099c6ad5c202c45f5a403f21a3411247f164e4a8fab056e5cd8a290f4a2"},
"ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"},
"ssl_verify_hostname": {:hex, :ssl_verify_hostname, "1.0.6", "45866d958d9ae51cfe8fef0050ab8054d25cba23ace43b88046092aa2c714645", [:make], [], "hexpm", "72b2fc8a8e23d77eed4441137fefa491bbf4a6dc52e9c0045f3f8e92e66243b5"},
"telemetry": {:hex, :telemetry, "0.4.2", "2808c992455e08d6177322f14d3bdb6b625fbcfd233a73505870d8738a2f4599", [:rebar3], [], "hexpm", "2d1419bd9dda6a206d7b5852179511722e2b18812310d304620c7bd92a13fcef"},
"telemetry": {:hex, :telemetry, "0.4.3", "a06428a514bdbc63293cd9a6263aad00ddeb66f608163bdec7c8995784080818", [:rebar3], [], "hexpm", "eb72b8365ffda5bed68a620d1da88525e326cb82a75ee61354fc24b844768041"},
"x509": {:hex, :x509, "0.6.0", "51a274a8368cf6fe771c1920ce5b075d28a076745987666593adb818aadf1bf7", [:mix], [], "hexpm", "111d29a5388e059413aa4553b64f041452fe49777be310622cd71e451d29630d"},
}
63 changes: 63 additions & 0 deletions test/plug/cowboy/conn_test.exs
Expand Up @@ -340,6 +340,69 @@ defmodule Plug.Cowboy.ConnTest do
request(:get, "/inform")
end

def upgrade_unsupported(conn) do
conn
|> upgrade_adapter(:unsupported, opt: :unsupported)
end

test "upgrade will not set the response" do
assert {500, _, body} = request(:get, "/upgrade_unsupported")
assert body =~ "upgrade to unsupported not supported by Plug.Cowboy.Conn"
end

defmodule NoopWebSocketHandler do
@behaviour :cowboy_websocket

# We never actually call this; it's just here to quell compiler warnings
@impl true
def init(req, state), do: {:cowboy_websocket, req, state}

@impl true
def websocket_handle(_frame, state), do: {:ok, state}

@impl true
def websocket_info(_msg, state), do: {:ok, state}
end

def upgrade_websocket(conn) do
# In actual use, it's the caller's responsibility to ensure the upgrade is valid before
# calling upgrade_adapter
conn
|> upgrade_adapter(:websocket, {NoopWebSocketHandler, [], %{}})
end

test "upgrades the connection when the connection is a valid websocket" do
{:ok, socket} = :gen_tcp.connect('localhost', 8003, active: false, mode: :binary)

:gen_tcp.send(socket, """
GET /upgrade_websocket HTTP/1.1\r
Host: server.example.com\r
Upgrade: websocket\r
Connection: Upgrade\r
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r
Sec-WebSocket-Version: 13\r
\r
""")

{:ok, response} = :gen_tcp.recv(socket, 234)

assert [
"HTTP/1.1 101 Switching Protocols",
"cache-control: max-age=0, private, must-revalidate",
"connection: Upgrade",
"date: " <> _date,
"sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"server: Cowboy",
"upgrade: websocket",
"",
""
] = String.split(response, "\r\n")
end

test "returns error in cases where an upgrade is indicated but the connection is not a valid upgrade" do
assert {426, _headers, ""} = request(:get, "/upgrade_websocket")
end

def push(conn) do
conn
|> push("/static/assets.css")
Expand Down
167 changes: 167 additions & 0 deletions test/plug/cowboy/websocket_handler_test.exs
@@ -0,0 +1,167 @@
defmodule WebSocketHandlerTest do
use ExUnit.Case, async: true

defmodule WebSocketHandler do
@behaviour :cowboy_websocket

# We never actually call this; it's just here to quell compiler warnings
@impl true
def init(req, state), do: {:cowboy_websocket, req, state}

@impl true
def websocket_init(_opts), do: {:ok, :init}

@impl true
def websocket_handle({:text, "state"}, state), do: {[{:text, inspect(state)}], state}

def websocket_handle({:text, "whoami"}, state),
do: {[{:text, :erlang.pid_to_list(self())}], state}

@impl true
def websocket_info(msg, state), do: {[{:text, inspect(msg)}], state}
end

@protocol_options [
idle_timeout: 1000,
request_timeout: 1000
]

setup_all do
{:ok, _} = Plug.Cowboy.http(__MODULE__, [], port: 8083, protocol_options: @protocol_options)
on_exit(fn -> :ok = Plug.Cowboy.shutdown(__MODULE__.HTTP) end)
{:ok, port: 8083}
end

@behaviour Plug

@impl Plug
def init(arg), do: arg

@impl Plug
def call(conn, _opts) do
conn = Plug.Conn.fetch_query_params(conn)
handler = conn.query_params["handler"] |> String.to_atom()
Plug.Conn.upgrade_adapter(conn, :websocket, {handler, [], %{idle_timeout: 1000}})
end

test "websocket_init and websocket_handle are called", context do
client = tcp_client(context)
http1_handshake(client, WebSocketHandler)

send_text_frame(client, "state")
{:ok, result} = recv_text_frame(client)
assert result == inspect(:init)
end

test "websocket_info is called", context do
client = tcp_client(context)
http1_handshake(client, WebSocketHandler)

send_text_frame(client, "whoami")
{:ok, pid} = recv_text_frame(client)
pid = pid |> String.to_charlist() |> :erlang.list_to_pid()

Process.send(pid, "hello info", [])

{:ok, response} = recv_text_frame(client)
assert response == inspect("hello info")
end

# Simple WebSocket client

def tcp_client(context) do
{:ok, socket} = :gen_tcp.connect('localhost', context[:port], active: false, mode: :binary)

socket
end

def http1_handshake(client, module, params \\ []) do
params = params |> Keyword.put(:handler, module)

:gen_tcp.send(client, """
GET /?#{URI.encode_query(params)} HTTP/1.1\r
Host: server.example.com\r
Upgrade: websocket\r
Connection: Upgrade\r
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r
Sec-WebSocket-Version: 13\r
\r
""")

{:ok, response} = :gen_tcp.recv(client, 234)

[
"HTTP/1.1 101 Switching Protocols",
"cache-control: max-age=0, private, must-revalidate",
"connection: Upgrade",
"date: " <> _date,
"sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"server: Cowboy",
"upgrade: websocket",
"",
""
] = String.split(response, "\r\n")
end

defp recv_text_frame(client) do
{:ok, 0x8, 0x1, body} = recv_frame(client)
{:ok, body}
end

defp recv_frame(client) do
{:ok, header} = :gen_tcp.recv(client, 2)
<<flags::4, opcode::4, 0::1, length::7>> = header

{:ok, data} =
case length do
0 ->
{:ok, <<>>}

126 ->
{:ok, <<length::16>>} = :gen_tcp.recv(client, 2)
:gen_tcp.recv(client, length)

127 ->
{:ok, <<length::64>>} = :gen_tcp.recv(client, 8)
:gen_tcp.recv(client, length)

length ->
:gen_tcp.recv(client, length)
end

{:ok, flags, opcode, data}
end

defp send_text_frame(client, data, flags \\ 0x8) do
send_frame(client, flags, 0x1, data)
end

defp send_frame(client, flags, opcode, data) do
mask = :rand.uniform(1_000_000)
masked_data = mask(data, mask)

mask_flag_and_size =
case byte_size(masked_data) do
size when size <= 125 -> <<1::1, size::7>>
size when size <= 65_535 -> <<1::1, 126::7, size::16>>
size -> <<1::1, 127::7, size::64>>
end

:gen_tcp.send(client, [<<flags::4, opcode::4>>, mask_flag_and_size, <<mask::32>>, masked_data])
end

# Note that masking is an involution, so we don't need a separate unmask function
defp mask(payload, mask, acc \\ <<>>)

defp mask(payload, mask, acc) when is_integer(mask), do: mask(payload, <<mask::32>>, acc)

defp mask(<<h::32, rest::binary>>, <<mask::32>>, acc) do
mask(rest, mask, acc <> <<Bitwise.bxor(h, mask)::32>>)
end

defp mask(<<h::8, rest::binary>>, <<current::8, mask::24>>, acc) do
mask(rest, <<mask::24, current::8>>, acc <> <<Bitwise.bxor(h, current)::8>>)
end

defp mask(<<>>, _mask, acc), do: acc
end