Skip to content

Commit

Permalink
Implement WebSocket upgrade support via the Sock API
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrudel committed Oct 16, 2022
1 parent 8c69a4f commit a14f4e2
Show file tree
Hide file tree
Showing 7 changed files with 1,031 additions and 13 deletions.
24 changes: 24 additions & 0 deletions lib/plug/cowboy.ex
Expand Up @@ -125,6 +125,30 @@ defmodule Plug.Cowboy do
To opt-out of this default instrumentation, you can manually configure
cowboy with the option `stream_handlers: [:cowboy_stream_h]`.
## WebSocket support
Plug.Cowboy supports upgrading HTTP requests to WebSocket connections via
the use of the `Plug.Conn.upgrade_adapter/3` function, called with `:websocket` as the second
argument. Applications should validate that the connection represents a valid WebSocket request
before calling this function (Plug.Cowboy will validate the connection as part of the upgrade
process, but does not provide any capacity for an application to be notified if the upgrade is
not successful). If an application wishes to negotiate WebSocket subprotocols or otherwise set
any response headers, it should do so before calling `Plug.Conn.upgrade_adapter/3`.
The third argument to `Plug.Conn.upgrade_adapter/3` defines the details of how Plug.Cowboy
should handle the WebSocket connection, and must take the form `{handler, handler_opts,
connection_opts}`, where values are as follows:
* `handler` is a module which implements the `Sock` API
* `handler_opts` is an arbitrary term which will be passed as the argument to `c:Sock.init/1`
* `connection_opts` is a keyword list which consists of zero or more of the following options:
* `timeout`: The number of milliseconds to wait after no client data is received before
closing the connection. Defaults to `60_000`
* `compress`: Whether or not to attempt negotiation of a compression extension with the
client. Defaults to `false`
* `max_frame_size`: The maximum frame size in bytes to accept from the client. The connection
will be closed if the client sends a frame larger than this suze. Defaults to `:infinity`
"""

require Logger
Expand Down
6 changes: 3 additions & 3 deletions lib/plug/cowboy/conn.ex
Expand Up @@ -91,9 +91,9 @@ defmodule Plug.Cowboy.Conn do
end

@impl true
def upgrade(_req, _protocol, _opts) do
{:error, :not_supported}
end
def upgrade(req, :websocket, opts), do: {:ok, req |> Map.put(:upgrade, {:websocket, opts})}

def upgrade(_req, _upgrade, _opts), do: {:error, :not_supported}

@impl true
def push(req, path, headers) do
Expand Down
120 changes: 115 additions & 5 deletions lib/plug/cowboy/handler.ex
@@ -1,18 +1,42 @@
defmodule Plug.Cowboy.Handler do
@moduledoc false

if Code.ensure_loaded?(:cowboy_websocket) and
function_exported?(:cowboy_websocket, :behaviour_info, 1) do
@behaviour :cowboy_websocket
end

@connection Plug.Cowboy.Conn
@already_sent {:plug_conn, :sent}

def init(req, {plug, opts}) do
conn = @connection.conn(req)

try do
%{adapter: {@connection, req}} =
conn
|> plug.call(opts)
|> maybe_send(plug)
conn
|> plug.call(opts)
|> maybe_send(plug)
|> case do
%Plug.Conn{adapter: {@connection, %{upgrade: {:websocket, websocket_opts}} = req}} = conn ->
{handler, state, connection_opts} = websocket_opts

{:ok, req, {plug, opts}}
cowboy_opts =
connection_opts
|> Enum.flat_map(fn
{:timeout, timeout} -> [idle_timeout: timeout]
{:compress, _} = opt -> [opt]
{:max_frame_size, _} = opt -> [opt]
_other -> []
end)
|> Map.new()

handler_opts = Keyword.take(connection_opts, [:fullsweep_after])
triplet = {handler, handler_opts, state}
{:cowboy_websocket, copy_resp_headers(conn, req), triplet, cowboy_opts}

%Plug.Conn{adapter: {@connection, req}} ->
{:ok, req, {plug, opts}}
end
catch
kind, reason ->
exit_on_error(kind, reason, __STACKTRACE__, {plug, :call, [conn, opts]})
Expand All @@ -25,6 +49,12 @@ defmodule Plug.Cowboy.Handler do
end
end

defp copy_resp_headers(%Plug.Conn{} = conn, req) do
Enum.reduce(conn.resp_headers, req, fn {key, val}, acc ->
:cowboy_req.set_resp_header(key, val, acc)
end)
end

defp exit_on_error(
:error,
%Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack},
Expand Down Expand Up @@ -55,4 +85,84 @@ defmodule Plug.Cowboy.Handler do
raise "Cowboy2 adapter expected #{inspect(plug)} to return Plug.Conn but got: " <>
inspect(other)
end

## Websocket callbacks

def websocket_init({handler, process_flags, state}) do
for {key, value} <- process_flags do
:erlang.process_flag(key, value)
end

handle_reply(handler, handler.init(state))
end

def websocket_handle({opcode, payload}, {:sock, handler, state})
when opcode in [:text, :binary] do
handle_reply(handler, handler.handle_in({payload, opcode: opcode}, state))
end

def websocket_handle({opcode, payload}, handler_state) when opcode in [:ping, :pong] do
handle_control_frame({payload, opcode: opcode}, handler_state)
end

def websocket_handle(opcode, handler_state) when opcode in [:ping, :pong] do
handle_control_frame({nil, opcode: opcode}, handler_state)
end

def websocket_handle(_other, handler_state) do
{:ok, handler_state}
end

def websocket_info(message, {:sock, handler, state}) do
handle_reply(handler, handler.handle_info(message, state))
end

defp handle_reply(handler, {:ok, state}), do: {:ok, {:sock, handler, state}}
defp handle_reply(handler, {:push, data, state}), do: {:reply, data, {:sock, handler, state}}

defp handle_reply(handler, {:reply, _status, data, state}),
do: {:reply, data, {:sock, handler, state}}

defp handle_reply(handler, {:stop, _reason, state}), do: {:stop, {:sock, handler, state}}

defp handle_control_frame(payload_with_opts, {:sock, handler, state}) do
reply =
if function_exported?(handler, :handle_control, 2) do
handler.handle_control(payload_with_opts, state)
else
{:ok, state}
end

handle_reply(handler, reply)
end

def terminate({:remote, :closed}, _req, {:sock, handler, state}) do
handler.terminate(:closed, state)
end

def terminate({:remote, code, _}, _req, {:sock, handler, state})
when code in 1000..1003 or code in 1005..1011 or code == 1015 do
handler.terminate(:remote, state)
end

def terminate(:remote, _req, {:sock, handler, state}) do
handler.terminate(:remote, state)
end

def terminate({:error, reason}, _req, {:sock, handler, state}) do
handler.terminate({:error, reason}, state)
end

def terminate(:stop, _req, {:sock, handler, state}) do
handler.terminate(:normal, state)
end

def terminate(reason, _req, {:sock, handler, state}) do
handler.terminate(reason, state)
end

# Note that this terminate/3 function is part of the Cowboy Handler API, not cowboy_websocket
def terminate(_reason, _req, _state) do
:ok
end
end
1 change: 1 addition & 0 deletions mix.exs
Expand Up @@ -34,6 +34,7 @@ defmodule Plug.Cowboy.MixProject do
def deps do
[
{:plug, "~> 1.7"},
{:sock, "~> 0.3.0"},
{:cowboy, "~> 2.7"},
{:cowboy_telemetry, "~> 0.3"},
{:ex_doc, "~> 0.20", only: :docs},
Expand Down
1 change: 1 addition & 0 deletions mix.lock
Expand Up @@ -16,6 +16,7 @@
"plug": {:hex, :plug, "1.11.0", "f17217525597628298998bc3baed9f8ea1fa3f1160aa9871aee6df47a6e4d38e", [:mix], [{:mime, "~> 1.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "2d9c633f0499f9dc5c2fd069161af4e2e7756890b81adcbb2ceaa074e8308876"},
"plug_crypto": {:hex, :plug_crypto, "1.2.0", "1cb20793aa63a6c619dd18bb33d7a3aa94818e5fd39ad357051a67f26dfa2df6", [:mix], [], "hexpm", "a48b538ae8bf381ffac344520755f3007cc10bd8e90b240af98ea29b69683fc2"},
"ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"},
"sock": {:hex, :sock, "0.3.0", "026496b408061a7ceea7486ed6d99439b74a82f26bbdd44bdbc78a9a80cfbe7b", [:mix], [], "hexpm", "27432ee0598f16068fba49c1436a39932c984e9fdbd4da870c0b2f9a1bb7d453"},
"ssl_verify_hostname": {:hex, :ssl_verify_hostname, "1.0.6", "45866d958d9ae51cfe8fef0050ab8054d25cba23ace43b88046092aa2c714645", [:make], [], "hexpm", "72b2fc8a8e23d77eed4441137fefa491bbf4a6dc52e9c0045f3f8e92e66243b5"},
"telemetry": {:hex, :telemetry, "0.4.2", "2808c992455e08d6177322f14d3bdb6b625fbcfd233a73505870d8738a2f4599", [:rebar3], [], "hexpm", "2d1419bd9dda6a206d7b5852179511722e2b18812310d304620c7bd92a13fcef"},
"x509": {:hex, :x509, "0.6.0", "51a274a8368cf6fe771c1920ce5b075d28a076745987666593adb818aadf1bf7", [:mix], [], "hexpm", "111d29a5388e059413aa4553b64f041452fe49777be310622cd71e451d29630d"},
Expand Down
65 changes: 60 additions & 5 deletions test/plug/cowboy/conn_test.exs
Expand Up @@ -340,14 +340,69 @@ defmodule Plug.Cowboy.ConnTest do
request(:get, "/inform")
end

def upgrade(conn) do
def upgrade_unsupported(conn) do
conn
|> upgrade_adapter(:unsupported, opt: :unsupported)
|> send_resp(200, "upgrade")
|> upgrade_adapter(:unsupported, nil)
|> send_resp(200, "Not supported")
end

test "upgrade will not set the response" do
assert {200, _headers, "upgrade"} = request(:get, "/upgrade")
test "does not update the conn or send any data on unsupported upgrades" do
assert {200, _headers, "Not supported"} = request(:get, "/upgrade_unsupported")
end

defmodule NoopSock do
@behaviour Sock

@impl true
def init(arg), do: {:ok, arg}

@impl true
def handle_in(_data, state), do: {:ok, state}

@impl true
def handle_info(_msg, state), do: {:ok, state}

@impl true
def terminate(_reason, _state), do: :ok
end

def upgrade_websocket(conn) do
# In actual use, it's the caller's responsibility to ensure the upgrade is valid before
# calling upgrade_adapter
conn
|> upgrade_adapter(:websocket, {NoopSock, [], []})
end

test "returns error in cases where an upgrade is indicated but the connection is not a valid upgrade" do
assert {426, _headers, ""} = request(:get, "/upgrade_websocket")
end

test "upgrades the connection when the connection is a valid websocket" do
{:ok, socket} = :gen_tcp.connect('localhost', 8003, active: false, mode: :binary)

:gen_tcp.send(socket, """
GET /upgrade_websocket HTTP/1.1\r
Host: server.example.com\r
Upgrade: websocket\r
Connection: Upgrade\r
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r
Sec-WebSocket-Version: 13\r
\r
""")

{:ok, response} = :gen_tcp.recv(socket, 234)

assert [
"HTTP/1.1 101 Switching Protocols",
"cache-control: max-age=0, private, must-revalidate",
"connection: Upgrade",
"date: " <> _date,
"sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"server: Cowboy",
"upgrade: websocket",
"",
""
] = String.split(response, "\r\n")
end

def push(conn) do
Expand Down

0 comments on commit a14f4e2

Please sign in to comment.