diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 19901e4..517fcb2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: matrix: include: - pair: - elixir: 1.7.4 + elixir: 1.10.4 otp: 22.3.4 - pair: elixir: 1.11.4 diff --git a/lib/plug/cowboy.ex b/lib/plug/cowboy.ex index 0728ab9..bc91be8 100644 --- a/lib/plug/cowboy.ex +++ b/lib/plug/cowboy.ex @@ -125,6 +125,29 @@ defmodule Plug.Cowboy do To opt-out of this default instrumentation, you can manually configure cowboy with the option `stream_handlers: [:cowboy_stream_h]`. + + ## WebSocket support + + Plug.Cowboy supports upgrading HTTP requests to WebSocket connections via + the use of the `Plug.Conn.upgrade_adapter/3` function, called with `:websocket` as the second + argument. Applications should validate that the connection represents a valid WebSocket request + before calling this function (Cowboy will validate the connection as part of the upgrade + process, but does not provide any capacity for an application to be notified if the upgrade is + not successful). If an application wishes to negotiate WebSocket subprotocols or otherwise set + any response headers, it should do so before calling `Plug.Conn.upgrade_adapter/3`. + + The third argument to `Plug.Conn.upgrade_adapter/3` defines the details of how Plug.Cowboy + should handle the WebSocket connection, and must take the form `{handler, handler_opts, + connection_opts}`, where values are as follows: + + * `handler` is a module which implements the + [`:cowboy_websocket`](https://ninenines.eu/docs/en/cowboy/2.6/manual/cowboy_websocket/) + behaviour. Note that this module will NOT have its `c:cowboy_websocket.init/2` callback + called; only the 'later' parts of the `:cowboy_websocket` lifecycle are supported + * `handler_opts` is an arbitrary term which will be passed as the argument to + `c:cowboy_websocket.websocket_init/1` + * `connection_opts` is a map with any of [Cowboy's websockets options](https://ninenines.eu/docs/en/cowboy/2.6/manual/cowboy_websocket/#_opts) + """ require Logger diff --git a/lib/plug/cowboy/conn.ex b/lib/plug/cowboy/conn.ex index 732584a..4385406 100644 --- a/lib/plug/cowboy/conn.ex +++ b/lib/plug/cowboy/conn.ex @@ -90,6 +90,23 @@ defmodule Plug.Cowboy.Conn do :cowboy_req.inform(status, to_headers_map(headers), req) end + @impl true + def upgrade(req, :websocket, args) do + case args do + {handler, _state, cowboy_opts} when is_atom(handler) and is_map(cowboy_opts) -> + :ok + + _ -> + raise ArgumentError, + "expected websocket upgrade on Cowboy to be on the format {handler :: atom(), arg :: term(), opts :: map()}, got: " <> + inspect(args) + end + + {:ok, Map.put(req, :upgrade, {:websocket, args})} + end + + def upgrade(_req, _protocol, _args), do: {:error, :not_supported} + @impl true def push(req, path, headers) do opts = diff --git a/lib/plug/cowboy/handler.ex b/lib/plug/cowboy/handler.ex index e7aed97..2dd1061 100644 --- a/lib/plug/cowboy/handler.ex +++ b/lib/plug/cowboy/handler.ex @@ -7,12 +7,17 @@ defmodule Plug.Cowboy.Handler do conn = @connection.conn(req) try do - %{adapter: {@connection, req}} = - conn - |> plug.call(opts) - |> maybe_send(plug) + conn + |> plug.call(opts) + |> maybe_send(plug) + |> case do + %Plug.Conn{adapter: {@connection, %{upgrade: {:websocket, websocket_args}} = req}} = conn -> + {handler, state, cowboy_opts} = websocket_args + {__MODULE__, copy_resp_headers(conn, req), {handler, state}, cowboy_opts} - {:ok, req, {plug, opts}} + %Plug.Conn{adapter: {@connection, req}} -> + {:ok, req, {plug, opts}} + end catch kind, reason -> exit_on_error(kind, reason, __STACKTRACE__, {plug, :call, [conn, opts]}) @@ -25,6 +30,16 @@ defmodule Plug.Cowboy.Handler do end end + def upgrade(req, env, __MODULE__, {handler, state}, opts) do + :cowboy_websocket.upgrade(req, env, handler, state, opts) + end + + defp copy_resp_headers(%Plug.Conn{} = conn, req) do + Enum.reduce(conn.resp_headers, req, fn {key, val}, acc -> + :cowboy_req.set_resp_header(key, val, acc) + end) + end + defp exit_on_error( :error, %Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack}, diff --git a/mix.exs b/mix.exs index e6dfb1c..e78a2a3 100644 --- a/mix.exs +++ b/mix.exs @@ -9,7 +9,7 @@ defmodule Plug.Cowboy.MixProject do [ app: :plug_cowboy, version: @version, - elixir: "~> 1.7", + elixir: "~> 1.10", deps: deps(), package: package(), description: @description, @@ -33,7 +33,7 @@ defmodule Plug.Cowboy.MixProject do def deps do [ - {:plug, "~> 1.7"}, + {:plug, "~> 1.14"}, {:cowboy, "~> 2.7"}, {:cowboy_telemetry, "~> 0.3"}, {:ex_doc, "~> 0.20", only: :docs}, diff --git a/mix.lock b/mix.lock index 5dbb2e2..8950cc4 100644 --- a/mix.lock +++ b/mix.lock @@ -11,12 +11,12 @@ "makeup": {:hex, :makeup, "1.0.5", "d5a830bc42c9800ce07dd97fa94669dfb93d3bf5fcf6ea7a0c67b2e0e4a7f26c", [:mix], [{:nimble_parsec, "~> 0.5 or ~> 1.0", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cfa158c02d3f5c0c665d0af11512fed3fba0144cf1aadee0f2ce17747fba2ca9"}, "makeup_elixir": {:hex, :makeup_elixir, "0.15.1", "b5888c880d17d1cc3e598f05cdb5b5a91b7b17ac4eaf5f297cb697663a1094dd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.1", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "db68c173234b07ab2a07f645a5acdc117b9f99d69ebf521821d89690ae6c6ec8"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"}, - "mime": {:hex, :mime, "1.5.0", "203ef35ef3389aae6d361918bf3f952fa17a09e8e43b5aa592b93eba05d0fb8d", [:mix], [], "hexpm", "55a94c0f552249fc1a3dd9cd2d3ab9de9d3c89b559c2bd01121f824834f24746"}, + "mime": {:hex, :mime, "2.0.3", "3676436d3d1f7b81b5a2d2bd8405f412c677558c81b1c92be58c00562bb59095", [:mix], [], "hexpm", "27a30bf0db44d25eecba73755acf4068cbfe26a4372f9eb3e4ea3a45956bff6b"}, "nimble_parsec": {:hex, :nimble_parsec, "1.1.0", "3a6fca1550363552e54c216debb6a9e95bd8d32348938e13de5eda962c0d7f89", [:mix], [], "hexpm", "08eb32d66b706e913ff748f11694b17981c0b04a33ef470e33e11b3d3ac8f54b"}, - "plug": {:hex, :plug, "1.11.0", "f17217525597628298998bc3baed9f8ea1fa3f1160aa9871aee6df47a6e4d38e", [:mix], [{:mime, "~> 1.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "2d9c633f0499f9dc5c2fd069161af4e2e7756890b81adcbb2ceaa074e8308876"}, - "plug_crypto": {:hex, :plug_crypto, "1.2.0", "1cb20793aa63a6c619dd18bb33d7a3aa94818e5fd39ad357051a67f26dfa2df6", [:mix], [], "hexpm", "a48b538ae8bf381ffac344520755f3007cc10bd8e90b240af98ea29b69683fc2"}, + "plug": {:hex, :plug, "1.14.0", "ba4f558468f69cbd9f6b356d25443d0b796fbdc887e03fa89001384a9cac638f", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.3 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bf020432c7d4feb7b3af16a0c2701455cbbbb95e5b6866132cb09eb0c29adc14"}, + "plug_crypto": {:hex, :plug_crypto, "1.2.3", "8f77d13aeb32bfd9e654cb68f0af517b371fb34c56c9f2b58fe3df1235c1251a", [:mix], [], "hexpm", "b5672099c6ad5c202c45f5a403f21a3411247f164e4a8fab056e5cd8a290f4a2"}, "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"}, "ssl_verify_hostname": {:hex, :ssl_verify_hostname, "1.0.6", "45866d958d9ae51cfe8fef0050ab8054d25cba23ace43b88046092aa2c714645", [:make], [], "hexpm", "72b2fc8a8e23d77eed4441137fefa491bbf4a6dc52e9c0045f3f8e92e66243b5"}, - "telemetry": {:hex, :telemetry, "0.4.2", "2808c992455e08d6177322f14d3bdb6b625fbcfd233a73505870d8738a2f4599", [:rebar3], [], "hexpm", "2d1419bd9dda6a206d7b5852179511722e2b18812310d304620c7bd92a13fcef"}, + "telemetry": {:hex, :telemetry, "0.4.3", "a06428a514bdbc63293cd9a6263aad00ddeb66f608163bdec7c8995784080818", [:rebar3], [], "hexpm", "eb72b8365ffda5bed68a620d1da88525e326cb82a75ee61354fc24b844768041"}, "x509": {:hex, :x509, "0.6.0", "51a274a8368cf6fe771c1920ce5b075d28a076745987666593adb818aadf1bf7", [:mix], [], "hexpm", "111d29a5388e059413aa4553b64f041452fe49777be310622cd71e451d29630d"}, } diff --git a/test/plug/cowboy/conn_test.exs b/test/plug/cowboy/conn_test.exs index cc1e44a..a5af995 100644 --- a/test/plug/cowboy/conn_test.exs +++ b/test/plug/cowboy/conn_test.exs @@ -340,6 +340,69 @@ defmodule Plug.Cowboy.ConnTest do request(:get, "/inform") end + def upgrade_unsupported(conn) do + conn + |> upgrade_adapter(:unsupported, opt: :unsupported) + end + + test "upgrade will not set the response" do + assert {500, _, body} = request(:get, "/upgrade_unsupported") + assert body =~ "upgrade to unsupported not supported by Plug.Cowboy.Conn" + end + + defmodule NoopWebSocketHandler do + @behaviour :cowboy_websocket + + # We never actually call this; it's just here to quell compiler warnings + @impl true + def init(req, state), do: {:cowboy_websocket, req, state} + + @impl true + def websocket_handle(_frame, state), do: {:ok, state} + + @impl true + def websocket_info(_msg, state), do: {:ok, state} + end + + def upgrade_websocket(conn) do + # In actual use, it's the caller's responsibility to ensure the upgrade is valid before + # calling upgrade_adapter + conn + |> upgrade_adapter(:websocket, {NoopWebSocketHandler, [], %{}}) + end + + test "upgrades the connection when the connection is a valid websocket" do + {:ok, socket} = :gen_tcp.connect('localhost', 8003, active: false, mode: :binary) + + :gen_tcp.send(socket, """ + GET /upgrade_websocket HTTP/1.1\r + Host: server.example.com\r + Upgrade: websocket\r + Connection: Upgrade\r + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r + Sec-WebSocket-Version: 13\r + \r + """) + + {:ok, response} = :gen_tcp.recv(socket, 234) + + assert [ + "HTTP/1.1 101 Switching Protocols", + "cache-control: max-age=0, private, must-revalidate", + "connection: Upgrade", + "date: " <> _date, + "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "server: Cowboy", + "upgrade: websocket", + "", + "" + ] = String.split(response, "\r\n") + end + + test "returns error in cases where an upgrade is indicated but the connection is not a valid upgrade" do + assert {426, _headers, ""} = request(:get, "/upgrade_websocket") + end + def push(conn) do conn |> push("/static/assets.css") diff --git a/test/plug/cowboy/websocket_handler_test.exs b/test/plug/cowboy/websocket_handler_test.exs new file mode 100644 index 0000000..ade399a --- /dev/null +++ b/test/plug/cowboy/websocket_handler_test.exs @@ -0,0 +1,167 @@ +defmodule WebSocketHandlerTest do + use ExUnit.Case, async: true + + defmodule WebSocketHandler do + @behaviour :cowboy_websocket + + # We never actually call this; it's just here to quell compiler warnings + @impl true + def init(req, state), do: {:cowboy_websocket, req, state} + + @impl true + def websocket_init(_opts), do: {:ok, :init} + + @impl true + def websocket_handle({:text, "state"}, state), do: {[{:text, inspect(state)}], state} + + def websocket_handle({:text, "whoami"}, state), + do: {[{:text, :erlang.pid_to_list(self())}], state} + + @impl true + def websocket_info(msg, state), do: {[{:text, inspect(msg)}], state} + end + + @protocol_options [ + idle_timeout: 1000, + request_timeout: 1000 + ] + + setup_all do + {:ok, _} = Plug.Cowboy.http(__MODULE__, [], port: 8083, protocol_options: @protocol_options) + on_exit(fn -> :ok = Plug.Cowboy.shutdown(__MODULE__.HTTP) end) + {:ok, port: 8083} + end + + @behaviour Plug + + @impl Plug + def init(arg), do: arg + + @impl Plug + def call(conn, _opts) do + conn = Plug.Conn.fetch_query_params(conn) + handler = conn.query_params["handler"] |> String.to_atom() + Plug.Conn.upgrade_adapter(conn, :websocket, {handler, [], %{idle_timeout: 1000}}) + end + + test "websocket_init and websocket_handle are called", context do + client = tcp_client(context) + http1_handshake(client, WebSocketHandler) + + send_text_frame(client, "state") + {:ok, result} = recv_text_frame(client) + assert result == inspect(:init) + end + + test "websocket_info is called", context do + client = tcp_client(context) + http1_handshake(client, WebSocketHandler) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + + Process.send(pid, "hello info", []) + + {:ok, response} = recv_text_frame(client) + assert response == inspect("hello info") + end + + # Simple WebSocket client + + def tcp_client(context) do + {:ok, socket} = :gen_tcp.connect('localhost', context[:port], active: false, mode: :binary) + + socket + end + + def http1_handshake(client, module, params \\ []) do + params = params |> Keyword.put(:handler, module) + + :gen_tcp.send(client, """ + GET /?#{URI.encode_query(params)} HTTP/1.1\r + Host: server.example.com\r + Upgrade: websocket\r + Connection: Upgrade\r + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r + Sec-WebSocket-Version: 13\r + \r + """) + + {:ok, response} = :gen_tcp.recv(client, 234) + + [ + "HTTP/1.1 101 Switching Protocols", + "cache-control: max-age=0, private, must-revalidate", + "connection: Upgrade", + "date: " <> _date, + "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "server: Cowboy", + "upgrade: websocket", + "", + "" + ] = String.split(response, "\r\n") + end + + defp recv_text_frame(client) do + {:ok, 0x8, 0x1, body} = recv_frame(client) + {:ok, body} + end + + defp recv_frame(client) do + {:ok, header} = :gen_tcp.recv(client, 2) + <> = header + + {:ok, data} = + case length do + 0 -> + {:ok, <<>>} + + 126 -> + {:ok, <>} = :gen_tcp.recv(client, 2) + :gen_tcp.recv(client, length) + + 127 -> + {:ok, <>} = :gen_tcp.recv(client, 8) + :gen_tcp.recv(client, length) + + length -> + :gen_tcp.recv(client, length) + end + + {:ok, flags, opcode, data} + end + + defp send_text_frame(client, data, flags \\ 0x8) do + send_frame(client, flags, 0x1, data) + end + + defp send_frame(client, flags, opcode, data) do + mask = :rand.uniform(1_000_000) + masked_data = mask(data, mask) + + mask_flag_and_size = + case byte_size(masked_data) do + size when size <= 125 -> <<1::1, size::7>> + size when size <= 65_535 -> <<1::1, 126::7, size::16>> + size -> <<1::1, 127::7, size::64>> + end + + :gen_tcp.send(client, [<>, mask_flag_and_size, <>, masked_data]) + end + + # Note that masking is an involution, so we don't need a separate unmask function + defp mask(payload, mask, acc \\ <<>>) + + defp mask(payload, mask, acc) when is_integer(mask), do: mask(payload, <>, acc) + + defp mask(<>, <>, acc) do + mask(rest, mask, acc <> <>) + end + + defp mask(<>, <>, acc) do + mask(rest, <>, acc <> <>) + end + + defp mask(<<>>, _mask, acc), do: acc +end