diff --git a/lib/plug/cowboy.ex b/lib/plug/cowboy.ex index 0728ab9..68c0ddd 100644 --- a/lib/plug/cowboy.ex +++ b/lib/plug/cowboy.ex @@ -125,6 +125,33 @@ defmodule Plug.Cowboy do To opt-out of this default instrumentation, you can manually configure cowboy with the option `stream_handlers: [:cowboy_stream_h]`. + + ## WebSocket support + + Plug.Cowboy supports upgrading HTTP requests to WebSocket connections via + the use of the `Plug.Conn.upgrade_adapter/3` function, called with `:websocket` as the second + argument. Applications should validate that the connection represents a valid WebSocket request + before calling this function (Cowboy will validate the connection as part of the upgrade + process, but does not provide any capacity for an application to be notified if the upgrade is + not successful). If an application wishes to negotiate WebSocket subprotocols or otherwise set + any response headers, it should do so before calling `Plug.Conn.upgrade_adapter/3`. + + The third argument to `Plug.Conn.upgrade_adapter/3` defines the details of how Plug.Cowboy + should handle the WebSocket connection, and must take the form `{handler, handler_opts, + connection_opts}`, where values are as follows: + + * `handler` is a module which implements the `:cowboy_websocket` behaviour. Note that this + * module will NOT have its `c:cowboy_websocket.init/2` callback called; only the 'later' parts + of the `:cowboy_websocket` lifecycle are supported + * `handler_opts` is an arbitrary term which will be passed as the argument to + `c:cowboy_websocket.websocket_init/1` + * `connection_opts` is a keyword list which consists of zero or more of the following options: + * `timeout`: The number of milliseconds to wait after no client data is received before + closing the connection. Defaults to `60_000` + * `compress`: Whether or not to attempt negotiation of a compression extension with the + client. Defaults to `false` + * `max_frame_size`: The maximum frame size in bytes to accept from the client. The connection + will be closed if the client sends a frame larger than this suze. Defaults to `:infinity` """ require Logger diff --git a/lib/plug/cowboy/conn.ex b/lib/plug/cowboy/conn.ex index b43fa1b..c8effc2 100644 --- a/lib/plug/cowboy/conn.ex +++ b/lib/plug/cowboy/conn.ex @@ -91,6 +91,7 @@ defmodule Plug.Cowboy.Conn do end @impl true + def upgrade(req, :websocket, args), do: {:ok, Map.put(req, :upgrade, {:websocket, args})} def upgrade(_req, _protocol, _args), do: {:error, :not_supported} @impl true diff --git a/lib/plug/cowboy/handler.ex b/lib/plug/cowboy/handler.ex index e7aed97..d5f39af 100644 --- a/lib/plug/cowboy/handler.ex +++ b/lib/plug/cowboy/handler.ex @@ -7,12 +7,28 @@ defmodule Plug.Cowboy.Handler do conn = @connection.conn(req) try do - %{adapter: {@connection, req}} = - conn - |> plug.call(opts) - |> maybe_send(plug) + conn + |> plug.call(opts) + |> maybe_send(plug) + |> case do + %Plug.Conn{adapter: {@connection, %{upgrade: {:websocket, websocket_args}} = req}} = conn -> + {handler, state, connection_opts} = websocket_args - {:ok, req, {plug, opts}} + cowboy_opts = + connection_opts + |> Enum.flat_map(fn + {:timeout, timeout} -> [idle_timeout: timeout] + {:compress, _} = opt -> [opt] + {:max_frame_size, _} = opt -> [opt] + _other -> [] + end) + |> Map.new() + + {__MODULE__, copy_resp_headers(conn, req), {handler, state}, cowboy_opts} + + %Plug.Conn{adapter: {@connection, req}} -> + {:ok, req, {plug, opts}} + end catch kind, reason -> exit_on_error(kind, reason, __STACKTRACE__, {plug, :call, [conn, opts]}) @@ -25,6 +41,16 @@ defmodule Plug.Cowboy.Handler do end end + def upgrade(req, env, __MODULE__, {handler, state}, opts) do + :cowboy_websocket.upgrade(req, env, handler, state, opts) + end + + defp copy_resp_headers(%Plug.Conn{} = conn, req) do + Enum.reduce(conn.resp_headers, req, fn {key, val}, acc -> + :cowboy_req.set_resp_header(key, val, acc) + end) + end + defp exit_on_error( :error, %Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack}, diff --git a/test/plug/cowboy/conn_test.exs b/test/plug/cowboy/conn_test.exs index 60a6dc1..8c8ea7f 100644 --- a/test/plug/cowboy/conn_test.exs +++ b/test/plug/cowboy/conn_test.exs @@ -350,6 +350,59 @@ defmodule Plug.Cowboy.ConnTest do assert body =~ "upgrade to unsupported not supported by Plug.Cowboy.Conn" end + defmodule NoopWebSocketHandler do + @behaviour :cowboy_websocket + + # We never actually call this; it's just here to quell compiler warnings + @impl true + def init(req, state), do: {:cowboy_websocket, req, state} + + @impl true + def websocket_handle(_frame, state), do: {:ok, state} + + @impl true + def websocket_info(_msg, state), do: {:ok, state} + end + + def upgrade_websocket(conn) do + # In actual use, it's the caller's responsibility to ensure the upgrade is valid before + # calling upgrade_adapter + conn + |> upgrade_adapter(:websocket, {NoopWebSocketHandler, [], []}) + end + + test "upgrades the connection when the connection is a valid websocket" do + {:ok, socket} = :gen_tcp.connect('localhost', 8003, active: false, mode: :binary) + + :gen_tcp.send(socket, """ + GET /upgrade_websocket HTTP/1.1\r + Host: server.example.com\r + Upgrade: websocket\r + Connection: Upgrade\r + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r + Sec-WebSocket-Version: 13\r + \r + """) + + {:ok, response} = :gen_tcp.recv(socket, 234) + + assert [ + "HTTP/1.1 101 Switching Protocols", + "cache-control: max-age=0, private, must-revalidate", + "connection: Upgrade", + "date: " <> _date, + "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "server: Cowboy", + "upgrade: websocket", + "", + "" + ] = String.split(response, "\r\n") + end + + test "returns error in cases where an upgrade is indicated but the connection is not a valid upgrade" do + assert {426, _headers, ""} = request(:get, "/upgrade_websocket") + end + def push(conn) do conn |> push("/static/assets.css") diff --git a/test/plug/cowboy/websocket_handler_test.exs b/test/plug/cowboy/websocket_handler_test.exs new file mode 100644 index 0000000..3f9f53b --- /dev/null +++ b/test/plug/cowboy/websocket_handler_test.exs @@ -0,0 +1,167 @@ +defmodule WebSocketHandlerTest do + use ExUnit.Case, async: true + + defmodule WebSocketHandler do + @behaviour :cowboy_websocket + + # We never actually call this; it's just here to quell compiler warnings + @impl true + def init(req, state), do: {:cowboy_websocket, req, state} + + @impl true + def websocket_init(_opts), do: {:ok, :init} + + @impl true + def websocket_handle({:text, "state"}, state), do: {[{:text, inspect(state)}], state} + + def websocket_handle({:text, "whoami"}, state), + do: {[{:text, :erlang.pid_to_list(self())}], state} + + @impl true + def websocket_info(msg, state), do: {[{:text, inspect(msg)}], state} + end + + @protocol_options [ + idle_timeout: 1000, + request_timeout: 1000 + ] + + setup_all do + {:ok, _} = Plug.Cowboy.http(__MODULE__, [], port: 8083, protocol_options: @protocol_options) + on_exit(fn -> :ok = Plug.Cowboy.shutdown(__MODULE__.HTTP) end) + {:ok, port: 8083} + end + + @behaviour Plug + + @impl Plug + def init(arg), do: arg + + @impl Plug + def call(conn, _opts) do + conn = Plug.Conn.fetch_query_params(conn) + handler = conn.query_params["handler"] |> String.to_atom() + Plug.Conn.upgrade_adapter(conn, :websocket, {handler, [], [timeout: 1000]}) + end + + test "websocket_init and websocket_handle are called", context do + client = tcp_client(context) + http1_handshake(client, WebSocketHandler) + + send_text_frame(client, "state") + {:ok, result} = recv_text_frame(client) + assert result == inspect(:init) + end + + test "websocket_info is called", context do + client = tcp_client(context) + http1_handshake(client, WebSocketHandler) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + + Process.send(pid, "hello info", []) + + {:ok, response} = recv_text_frame(client) + assert response == inspect("hello info") + end + + # Simple WebSocket client + + def tcp_client(context) do + {:ok, socket} = :gen_tcp.connect('localhost', context[:port], active: false, mode: :binary) + + socket + end + + def http1_handshake(client, module, params \\ []) do + params = params |> Keyword.put(:handler, module) + + :gen_tcp.send(client, """ + GET /?#{URI.encode_query(params)} HTTP/1.1\r + Host: server.example.com\r + Upgrade: websocket\r + Connection: Upgrade\r + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r + Sec-WebSocket-Version: 13\r + \r + """) + + {:ok, response} = :gen_tcp.recv(client, 234) + + [ + "HTTP/1.1 101 Switching Protocols", + "cache-control: max-age=0, private, must-revalidate", + "connection: Upgrade", + "date: " <> _date, + "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "server: Cowboy", + "upgrade: websocket", + "", + "" + ] = String.split(response, "\r\n") + end + + defp recv_text_frame(client) do + {:ok, 0x8, 0x1, body} = recv_frame(client) + {:ok, body} + end + + defp recv_frame(client) do + {:ok, header} = :gen_tcp.recv(client, 2) + <> = header + + {:ok, data} = + case length do + 0 -> + {:ok, <<>>} + + 126 -> + {:ok, <>} = :gen_tcp.recv(client, 2) + :gen_tcp.recv(client, length) + + 127 -> + {:ok, <>} = :gen_tcp.recv(client, 8) + :gen_tcp.recv(client, length) + + length -> + :gen_tcp.recv(client, length) + end + + {:ok, flags, opcode, data} + end + + defp send_text_frame(client, data, flags \\ 0x8) do + send_frame(client, flags, 0x1, data) + end + + defp send_frame(client, flags, opcode, data) do + mask = :rand.uniform(1_000_000) + masked_data = mask(data, mask) + + mask_flag_and_size = + case byte_size(masked_data) do + size when size <= 125 -> <<1::1, size::7>> + size when size <= 65_535 -> <<1::1, 126::7, size::16>> + size -> <<1::1, 127::7, size::64>> + end + + :gen_tcp.send(client, [<>, mask_flag_and_size, <>, masked_data]) + end + + # Note that masking is an involution, so we don't need a separate unmask function + defp mask(payload, mask, acc \\ <<>>) + + defp mask(payload, mask, acc) when is_integer(mask), do: mask(payload, <>, acc) + + defp mask(<>, <>, acc) do + mask(rest, mask, acc <> <>) + end + + defp mask(<>, <>, acc) do + mask(rest, <>, acc <> <>) + end + + defp mask(<<>>, _mask, acc), do: acc +end