diff --git a/lib/plug/cowboy.ex b/lib/plug/cowboy.ex index 0728ab9..d2d90d8 100644 --- a/lib/plug/cowboy.ex +++ b/lib/plug/cowboy.ex @@ -125,6 +125,30 @@ defmodule Plug.Cowboy do To opt-out of this default instrumentation, you can manually configure cowboy with the option `stream_handlers: [:cowboy_stream_h]`. + + ## WebSocket support + + Plug.Cowboy supports upgrading HTTP requests to WebSocket connections via + the use of the `Plug.Conn.upgrade_adapter/3` function, called with `:websocket` as the second + argument. Applications should validate that the connection represents a valid WebSocket request + before calling this function (Plug.Cowboy will validate the connection as part of the upgrade + process, but does not provide any capacity for an application to be notified if the upgrade is + not successful). If an application wishes to negotiate WebSocket subprotocols or otherwise set + any response headers, it should do so before calling `Plug.Conn.upgrade_adapter/3`. + + The third argument to `Plug.Conn.upgrade_adapter/3` defines the details of how Plug.Cowboy + should handle the WebSocket connection, and must take the form `{handler, handler_opts, + connection_opts}`, where values are as follows: + + * `handler` is a module which implements the `Sock` API + * `handler_opts` is an arbitrary term which will be passed as the argument to `c:Sock.init/1` + * `connection_opts` is a keyword list which consists of zero or more of the following options: + * `timeout`: The number of milliseconds to wait after no client data is received before + closing the connection. Defaults to `60_000` + * `compress`: Whether or not to attempt negotiation of a compression extension with the + client. Defaults to `false` + * `max_frame_size`: The maximum frame size in bytes to accept from the client. The connection + will be closed if the client sends a frame larger than this suze. Defaults to `:infinity` """ require Logger diff --git a/lib/plug/cowboy/conn.ex b/lib/plug/cowboy/conn.ex index b8b8544..4982606 100644 --- a/lib/plug/cowboy/conn.ex +++ b/lib/plug/cowboy/conn.ex @@ -91,9 +91,9 @@ defmodule Plug.Cowboy.Conn do end @impl true - def upgrade(_req, _protocol, _opts) do - {:error, :not_supported} - end + def upgrade(req, :websocket, opts), do: {:ok, req |> Map.put(:upgrade, {:websocket, opts})} + + def upgrade(_req, _upgrade, _opts), do: {:error, :not_supported} @impl true def push(req, path, headers) do diff --git a/lib/plug/cowboy/handler.ex b/lib/plug/cowboy/handler.ex index e7aed97..196b836 100644 --- a/lib/plug/cowboy/handler.ex +++ b/lib/plug/cowboy/handler.ex @@ -1,5 +1,11 @@ defmodule Plug.Cowboy.Handler do @moduledoc false + + if Code.ensure_loaded?(:cowboy_websocket) and + function_exported?(:cowboy_websocket, :behaviour_info, 1) do + @behaviour :cowboy_websocket + end + @connection Plug.Cowboy.Conn @already_sent {:plug_conn, :sent} @@ -7,12 +13,30 @@ defmodule Plug.Cowboy.Handler do conn = @connection.conn(req) try do - %{adapter: {@connection, req}} = - conn - |> plug.call(opts) - |> maybe_send(plug) + conn + |> plug.call(opts) + |> maybe_send(plug) + |> case do + %Plug.Conn{adapter: {@connection, %{upgrade: {:websocket, websocket_opts}} = req}} = conn -> + {handler, state, connection_opts} = websocket_opts - {:ok, req, {plug, opts}} + cowboy_opts = + connection_opts + |> Enum.flat_map(fn + {:timeout, timeout} -> [idle_timeout: timeout] + {:compress, _} = opt -> [opt] + {:max_frame_size, _} = opt -> [opt] + _other -> [] + end) + |> Map.new() + + handler_opts = Keyword.take(connection_opts, [:fullsweep_after]) + triplet = {handler, handler_opts, state} + {:cowboy_websocket, copy_resp_headers(conn, req), triplet, cowboy_opts} + + %Plug.Conn{adapter: {@connection, req}} -> + {:ok, req, {plug, opts}} + end catch kind, reason -> exit_on_error(kind, reason, __STACKTRACE__, {plug, :call, [conn, opts]}) @@ -25,6 +49,12 @@ defmodule Plug.Cowboy.Handler do end end + defp copy_resp_headers(%Plug.Conn{} = conn, req) do + Enum.reduce(conn.resp_headers, req, fn {key, val}, acc -> + :cowboy_req.set_resp_header(key, val, acc) + end) + end + defp exit_on_error( :error, %Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack}, @@ -55,4 +85,84 @@ defmodule Plug.Cowboy.Handler do raise "Cowboy2 adapter expected #{inspect(plug)} to return Plug.Conn but got: " <> inspect(other) end + + ## Websocket callbacks + + def websocket_init({handler, process_flags, state}) do + for {key, value} <- process_flags do + :erlang.process_flag(key, value) + end + + handle_reply(handler, handler.init(state)) + end + + def websocket_handle({opcode, payload}, {:sock, handler, state}) + when opcode in [:text, :binary] do + handle_reply(handler, handler.handle_in({payload, opcode: opcode}, state)) + end + + def websocket_handle({opcode, payload}, handler_state) when opcode in [:ping, :pong] do + handle_control_frame({payload, opcode: opcode}, handler_state) + end + + def websocket_handle(opcode, handler_state) when opcode in [:ping, :pong] do + handle_control_frame({nil, opcode: opcode}, handler_state) + end + + def websocket_handle(_other, handler_state) do + {:ok, handler_state} + end + + def websocket_info(message, {:sock, handler, state}) do + handle_reply(handler, handler.handle_info(message, state)) + end + + defp handle_reply(handler, {:ok, state}), do: {:ok, {:sock, handler, state}} + defp handle_reply(handler, {:push, data, state}), do: {:reply, data, {:sock, handler, state}} + + defp handle_reply(handler, {:reply, _status, data, state}), + do: {:reply, data, {:sock, handler, state}} + + defp handle_reply(handler, {:stop, _reason, state}), do: {:stop, {:sock, handler, state}} + + defp handle_control_frame(payload_with_opts, {:sock, handler, state}) do + reply = + if function_exported?(handler, :handle_control, 2) do + handler.handle_control(payload_with_opts, state) + else + {:ok, state} + end + + handle_reply(handler, reply) + end + + def terminate({:remote, :closed}, _req, {:sock, handler, state}) do + handler.terminate(:closed, state) + end + + def terminate({:remote, code, _}, _req, {:sock, handler, state}) + when code in 1000..1003 or code in 1005..1011 or code == 1015 do + handler.terminate(:remote, state) + end + + def terminate(:remote, _req, {:sock, handler, state}) do + handler.terminate(:remote, state) + end + + def terminate({:error, reason}, _req, {:sock, handler, state}) do + handler.terminate({:error, reason}, state) + end + + def terminate(:stop, _req, {:sock, handler, state}) do + handler.terminate(:normal, state) + end + + def terminate(reason, _req, {:sock, handler, state}) do + handler.terminate(reason, state) + end + + # Note that this terminate/3 function is part of the Cowboy Handler API, not cowboy_websocket + def terminate(_reason, _req, _state) do + :ok + end end diff --git a/mix.exs b/mix.exs index e6dfb1c..af25247 100644 --- a/mix.exs +++ b/mix.exs @@ -34,6 +34,7 @@ defmodule Plug.Cowboy.MixProject do def deps do [ {:plug, "~> 1.7"}, + {:sock, "~> 0.3.0"}, {:cowboy, "~> 2.7"}, {:cowboy_telemetry, "~> 0.3"}, {:ex_doc, "~> 0.20", only: :docs}, diff --git a/mix.lock b/mix.lock index 5dbb2e2..e87f5da 100644 --- a/mix.lock +++ b/mix.lock @@ -16,6 +16,7 @@ "plug": {:hex, :plug, "1.11.0", "f17217525597628298998bc3baed9f8ea1fa3f1160aa9871aee6df47a6e4d38e", [:mix], [{:mime, "~> 1.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_crypto, "~> 1.1.1 or ~> 1.2", [hex: :plug_crypto, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "2d9c633f0499f9dc5c2fd069161af4e2e7756890b81adcbb2ceaa074e8308876"}, "plug_crypto": {:hex, :plug_crypto, "1.2.0", "1cb20793aa63a6c619dd18bb33d7a3aa94818e5fd39ad357051a67f26dfa2df6", [:mix], [], "hexpm", "a48b538ae8bf381ffac344520755f3007cc10bd8e90b240af98ea29b69683fc2"}, "ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"}, + "sock": {:hex, :sock, "0.3.0", "026496b408061a7ceea7486ed6d99439b74a82f26bbdd44bdbc78a9a80cfbe7b", [:mix], [], "hexpm", "27432ee0598f16068fba49c1436a39932c984e9fdbd4da870c0b2f9a1bb7d453"}, "ssl_verify_hostname": {:hex, :ssl_verify_hostname, "1.0.6", "45866d958d9ae51cfe8fef0050ab8054d25cba23ace43b88046092aa2c714645", [:make], [], "hexpm", "72b2fc8a8e23d77eed4441137fefa491bbf4a6dc52e9c0045f3f8e92e66243b5"}, "telemetry": {:hex, :telemetry, "0.4.2", "2808c992455e08d6177322f14d3bdb6b625fbcfd233a73505870d8738a2f4599", [:rebar3], [], "hexpm", "2d1419bd9dda6a206d7b5852179511722e2b18812310d304620c7bd92a13fcef"}, "x509": {:hex, :x509, "0.6.0", "51a274a8368cf6fe771c1920ce5b075d28a076745987666593adb818aadf1bf7", [:mix], [], "hexpm", "111d29a5388e059413aa4553b64f041452fe49777be310622cd71e451d29630d"}, diff --git a/test/plug/cowboy/conn_test.exs b/test/plug/cowboy/conn_test.exs index 3e2bbfa..7848411 100644 --- a/test/plug/cowboy/conn_test.exs +++ b/test/plug/cowboy/conn_test.exs @@ -340,14 +340,69 @@ defmodule Plug.Cowboy.ConnTest do request(:get, "/inform") end - def upgrade(conn) do + def upgrade_unsupported(conn) do conn - |> upgrade_adapter(:unsupported, opt: :unsupported) - |> send_resp(200, "upgrade") + |> upgrade_adapter(:unsupported, nil) + |> send_resp(200, "Not supported") end - test "upgrade will not set the response" do - assert {200, _headers, "upgrade"} = request(:get, "/upgrade") + test "does not update the conn or send any data on unsupported upgrades" do + assert {200, _headers, "Not supported"} = request(:get, "/upgrade_unsupported") + end + + defmodule NoopSock do + @behaviour Sock + + @impl true + def init(arg), do: {:ok, arg} + + @impl true + def handle_in(_data, state), do: {:ok, state} + + @impl true + def handle_info(_msg, state), do: {:ok, state} + + @impl true + def terminate(_reason, _state), do: :ok + end + + def upgrade_websocket(conn) do + # In actual use, it's the caller's responsibility to ensure the upgrade is valid before + # calling upgrade_adapter + conn + |> upgrade_adapter(:websocket, {NoopSock, [], []}) + end + + test "returns error in cases where an upgrade is indicated but the connection is not a valid upgrade" do + assert {426, _headers, ""} = request(:get, "/upgrade_websocket") + end + + test "upgrades the connection when the connection is a valid websocket" do + {:ok, socket} = :gen_tcp.connect('localhost', 8003, active: false, mode: :binary) + + :gen_tcp.send(socket, """ + GET /upgrade_websocket HTTP/1.1\r + Host: server.example.com\r + Upgrade: websocket\r + Connection: Upgrade\r + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r + Sec-WebSocket-Version: 13\r + \r + """) + + {:ok, response} = :gen_tcp.recv(socket, 234) + + assert [ + "HTTP/1.1 101 Switching Protocols", + "cache-control: max-age=0, private, must-revalidate", + "connection: Upgrade", + "date: " <> _date, + "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "server: Cowboy", + "upgrade: websocket", + "", + "" + ] = String.split(response, "\r\n") end def push(conn) do diff --git a/test/plug/cowboy/sock_test.exs b/test/plug/cowboy/sock_test.exs new file mode 100644 index 0000000..bf1bd26 --- /dev/null +++ b/test/plug/cowboy/sock_test.exs @@ -0,0 +1,827 @@ +defmodule WebSocketSockTest do + use ExUnit.Case, async: true + + defmodule NoopSock do + defmacro __using__(_) do + quote do + @behaviour Sock + + @impl true + def init(arg), do: {:ok, arg} + + @impl true + def handle_in(_data, state), do: {:ok, state} + + @impl true + def handle_info(_msg, state), do: {:ok, state} + + @impl true + def terminate(_reason, _state), do: :ok + + defoverridable init: 1, handle_in: 2, handle_info: 2, terminate: 2 + end + end + end + + @protocol_options [ + idle_timeout: 1000, + request_timeout: 1000 + ] + + setup_all do + {:ok, _} = Plug.Cowboy.http(__MODULE__, [], port: 8083, protocol_options: @protocol_options) + + on_exit(fn -> + :ok = Plug.Cowboy.shutdown(__MODULE__.HTTP) + end) + + {:ok, port: 8083} + end + + @behaviour Plug + + @impl Plug + def init(arg), do: arg + + @impl Plug + def call(conn, _opts) do + conn = Plug.Conn.fetch_query_params(conn) + sock = conn.query_params["sock"] |> String.to_atom() + Plug.Conn.upgrade_adapter(conn, :websocket, {sock, [], [timeout: 1000]}) + end + + describe "init" do + defmodule InitOKStateSock do + use NoopSock + def init(_opts), do: {:ok, :init} + def handle_in(_data, state), do: {:push, {:text, inspect(state)}, state} + end + + test "can return an ok tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, InitOKStateSock) + + send_text_frame(client, "OK") + {:ok, result} = recv_text_frame(client) + assert result == inspect(:init) + end + + defmodule InitPushStateSock do + use NoopSock + def init(_opts), do: {:push, {:text, "init"}, :init} + def handle_in(_data, state), do: {:push, {:text, inspect(state)}, state} + end + + test "can return a push tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, InitPushStateSock) + + # Ignore the frame it pushes us + _ = recv_text_frame(client) + + send_text_frame(client, "OK") + {:ok, response} = recv_text_frame(client) + assert response == inspect(:init) + end + + defmodule InitReplyStateSock do + use NoopSock + def init(_opts), do: {:reply, :ok, {:text, "init"}, :init} + def handle_in(_data, state), do: {:push, {:text, inspect(state)}, state} + end + + test "can return a reply tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, InitReplyStateSock) + + # Ignore the frame it pushes us + _ = recv_text_frame(client) + + send_text_frame(client, "OK") + {:ok, response} = recv_text_frame(client) + assert response == inspect(:init) + end + + defmodule InitTextSock do + use NoopSock + def init(_opts), do: {:push, {:text, "TEXT"}, :init} + end + + test "can return a text frame", context do + client = tcp_client(context) + http1_handshake(client, InitTextSock) + + assert recv_text_frame(client) == {:ok, "TEXT"} + end + + defmodule InitBinarySock do + use NoopSock + def init(_opts), do: {:push, {:binary, "BINARY"}, :init} + end + + test "can return a binary frame", context do + client = tcp_client(context) + http1_handshake(client, InitBinarySock) + + assert recv_binary_frame(client) == {:ok, "BINARY"} + end + + defmodule InitPingSock do + use NoopSock + def init(_opts), do: {:push, {:ping, "PING"}, :init} + end + + test "can return a ping frame", context do + client = tcp_client(context) + http1_handshake(client, InitPingSock) + + assert recv_ping_frame(client) == {:ok, "PING"} + end + + defmodule InitPongSock do + use NoopSock + def init(_opts), do: {:push, {:pong, "PONG"}, :init} + end + + test "can return a pong frame", context do + client = tcp_client(context) + http1_handshake(client, InitPongSock) + + assert recv_pong_frame(client) == {:ok, "PONG"} + end + + defmodule InitCloseSock do + use NoopSock + def init(_opts), do: {:stop, :normal, :init} + end + + test "can close a connection by returning a stop tuple", context do + client = tcp_client(context) + http1_handshake(client, InitCloseSock) + + assert recv_connection_close_frame(client) == {:ok, <<1000::16>>} + assert connection_closed_for_reading?(client) + end + end + + describe "handle_in" do + defmodule HandleInEchoSock do + use NoopSock + def handle_in({data, opcode: opcode}, state), do: {:push, {opcode, data}, state} + end + + test "can receive a text frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInEchoSock) + + send_text_frame(client, "OK") + + assert recv_text_frame(client) == {:ok, "OK"} + end + + test "can receive a bianry frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInEchoSock) + + send_binary_frame(client, "OK") + + assert recv_binary_frame(client) == {:ok, "OK"} + end + + defmodule HandleInStateSock do + use NoopSock + def init(_opts), do: {:ok, []} + + def handle_in({"dump", opcode: :text} = data, state), + do: {:push, {:text, inspect(state)}, [data | state]} + + def handle_in(data, state), do: {:ok, [data | state]} + end + + test "can return an ok tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleInStateSock) + + send_text_frame(client, "OK") + send_text_frame(client, "dump") + + {:ok, response} = recv_text_frame(client) + assert response == inspect([{"OK", opcode: :text}]) + end + + test "can return a push tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleInStateSock) + + send_text_frame(client, "dump") + _ = recv_text_frame(client) + send_text_frame(client, "dump") + + {:ok, response} = recv_text_frame(client) + assert response == inspect([{"dump", opcode: :text}]) + end + + defmodule HandleInReplyStateSock do + use NoopSock + def init(_opts), do: {:ok, []} + + def handle_in({"dump", opcode: :text} = data, state), + do: {:reply, :ok, {:text, inspect(state)}, [data | state]} + + def handle_in(data, state), do: {:ok, [data | state]} + end + + test "can return a reply tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleInReplyStateSock) + + send_text_frame(client, "dump") + _ = recv_text_frame(client) + send_text_frame(client, "dump") + + {:ok, response} = recv_text_frame(client) + assert response == inspect([{"dump", opcode: :text}]) + end + + defmodule HandleInTextSock do + use NoopSock + def handle_in(_data, state), do: {:push, {:text, "TEXT"}, state} + end + + test "can return a text frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInTextSock) + + send_text_frame(client, "OK") + + assert recv_text_frame(client) == {:ok, "TEXT"} + end + + defmodule HandleInBinarySock do + use NoopSock + def handle_in(_data, state), do: {:push, {:binary, "BINARY"}, state} + end + + test "can return a binary frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInBinarySock) + + send_text_frame(client, "OK") + + assert recv_binary_frame(client) == {:ok, "BINARY"} + end + + defmodule HandleInPingSock do + use NoopSock + def handle_in(_data, state), do: {:push, {:ping, "PING"}, state} + end + + test "can return a ping frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInPingSock) + + send_text_frame(client, "OK") + + assert recv_ping_frame(client) == {:ok, "PING"} + end + + defmodule HandleInPongSock do + use NoopSock + def handle_in(_data, state), do: {:push, {:pong, "PONG"}, state} + end + + test "can return a pong frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInPongSock) + + send_text_frame(client, "OK") + + assert recv_pong_frame(client) == {:ok, "PONG"} + end + + defmodule HandleInCloseSock do + use NoopSock + def handle_in(_data, state), do: {:stop, :normal, state} + end + + test "can close a connection by returning a stop tuple", context do + client = tcp_client(context) + http1_handshake(client, HandleInCloseSock) + + send_text_frame(client, "OK") + + assert recv_connection_close_frame(client) == {:ok, <<1000::16>>} + assert connection_closed_for_reading?(client) + end + end + + describe "handle_control" do + defmodule HandleControlNoImplSock do + use NoopSock + def handle_in({data, opcode: opcode}, state), do: {:push, {opcode, data}, state} + end + + test "callback is optional", context do + client = tcp_client(context) + http1_handshake(client, HandleControlNoImplSock) + + send_ping_frame(client, "OK") + assert recv_pong_frame(client) + + # Test that the connection is still alive + send_text_frame(client, "OK") + assert recv_text_frame(client) == {:ok, "OK"} + end + + defmodule HandleControlEchoSock do + use NoopSock + def handle_control({data, opcode: opcode}, state), do: {:push, {opcode, data}, state} + end + + test "can receive a ping frame", context do + client = tcp_client(context) + http1_handshake(client, HandleControlEchoSock) + + send_ping_frame(client, "OK") + _ = recv_pong_frame(client) + + assert recv_ping_frame(client) == {:ok, "OK"} + end + + test "can receive a pong frame", context do + client = tcp_client(context) + http1_handshake(client, HandleControlEchoSock) + + send_pong_frame(client, "OK") + + assert recv_pong_frame(client) == {:ok, "OK"} + end + + defmodule HandleControlStateSock do + use NoopSock + def init(_opts), do: {:ok, []} + + def handle_control({"dump", opcode: :ping} = data, state), + do: {:push, {:ping, inspect(state)}, [data | state]} + + def handle_control(data, state), do: {:ok, [data | state]} + end + + test "can return an ok tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleControlStateSock) + + send_ping_frame(client, "OK") + _ = recv_pong_frame(client) + send_ping_frame(client, "dump") + _ = recv_pong_frame(client) + + {:ok, response} = recv_ping_frame(client) + assert response == inspect([{"OK", opcode: :ping}]) + end + + test "can return a push tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleControlStateSock) + + send_ping_frame(client, "dump") + _ = recv_pong_frame(client) + _ = recv_ping_frame(client) + send_ping_frame(client, "dump") + _ = recv_pong_frame(client) + + {:ok, response} = recv_ping_frame(client) + assert response == inspect([{"dump", opcode: :ping}]) + end + + defmodule HandleControlReplyStateSock do + use NoopSock + def init(_opts), do: {:ok, []} + + def handle_control({"dump", opcode: :ping} = data, state), + do: {:reply, :ok, {:ping, inspect(state)}, [data | state]} + + def handle_control(data, state), do: {:ok, [data | state]} + end + + test "can return a reply tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleControlReplyStateSock) + + send_ping_frame(client, "dump") + _ = recv_pong_frame(client) + _ = recv_ping_frame(client) + send_ping_frame(client, "dump") + _ = recv_pong_frame(client) + + {:ok, response} = recv_ping_frame(client) + assert response == inspect([{"dump", opcode: :ping}]) + end + + defmodule HandleControlTextSock do + use NoopSock + def handle_control(_data, state), do: {:push, {:text, "TEXT"}, state} + end + + test "can return a text frame", context do + client = tcp_client(context) + http1_handshake(client, HandleControlTextSock) + + send_ping_frame(client, "OK") + _ = recv_pong_frame(client) + + assert recv_text_frame(client) == {:ok, "TEXT"} + end + + defmodule HandleControlBinarySock do + use NoopSock + def handle_control(_data, state), do: {:push, {:binary, "BINARY"}, state} + end + + test "can return a binary frame", context do + client = tcp_client(context) + http1_handshake(client, HandleControlBinarySock) + + send_ping_frame(client, "OK") + _ = recv_pong_frame(client) + + assert recv_binary_frame(client) == {:ok, "BINARY"} + end + + defmodule HandleControlPingSock do + use NoopSock + def handle_control(_data, state), do: {:push, {:ping, "PING"}, state} + end + + test "can return a ping frame", context do + client = tcp_client(context) + http1_handshake(client, HandleControlPingSock) + + send_ping_frame(client, "OK") + _ = recv_pong_frame(client) + + assert recv_ping_frame(client) == {:ok, "PING"} + end + + defmodule HandleControlPongSock do + use NoopSock + def handle_control(_data, state), do: {:push, {:pong, "PONG"}, state} + end + + test "can return a pong frame", context do + client = tcp_client(context) + http1_handshake(client, HandleControlPongSock) + + send_ping_frame(client, "OK") + _ = recv_pong_frame(client) + + assert recv_pong_frame(client) == {:ok, "PONG"} + end + + defmodule HandleControlCloseSock do + use NoopSock + def handle_control(_data, state), do: {:stop, :normal, state} + end + + test "can close a connection by returning a stop tuple", context do + client = tcp_client(context) + http1_handshake(client, HandleControlCloseSock) + + send_ping_frame(client, "OK") + _ = recv_pong_frame(client) + + assert recv_connection_close_frame(client) == {:ok, <<1000::16>>} + assert connection_closed_for_reading?(client) + end + end + + describe "handle_info" do + defmodule HandleInfoStateSock do + use NoopSock + def init(_opts), do: {:ok, []} + def handle_in(_data, state), do: {:push, {:text, :erlang.pid_to_list(self())}, state} + def handle_info("dump" = data, state), do: {:push, {:text, inspect(state)}, [data | state]} + def handle_info(data, state), do: {:ok, [data | state]} + end + + test "can return an ok tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleInfoStateSock) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + + Process.send(pid, "OK", []) + Process.send(pid, "dump", []) + + {:ok, response} = recv_text_frame(client) + assert response == inspect(["OK"]) + end + + test "can return a push tuple and update state", context do + client = tcp_client(context) + http1_handshake(client, HandleInfoStateSock) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + + Process.send(pid, "dump", []) + _ = recv_text_frame(client) + Process.send(pid, "dump", []) + + {:ok, response} = recv_text_frame(client) + assert response == inspect(["dump"]) + end + + defmodule HandleInfoTextSock do + use NoopSock + def handle_in(_data, state), do: {:push, {:text, :erlang.pid_to_list(self())}, state} + def handle_info(_data, state), do: {:push, {:text, "TEXT"}, state} + end + + test "can return a text frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInfoTextSock) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + Process.send(pid, "OK", []) + + assert recv_text_frame(client) == {:ok, "TEXT"} + end + + defmodule HandleInfoBinarySock do + use NoopSock + def handle_in(_data, state), do: {:push, {:text, :erlang.pid_to_list(self())}, state} + def handle_info(_data, state), do: {:push, {:binary, "BINARY"}, state} + end + + test "can return a binary frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInfoBinarySock) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + Process.send(pid, "OK", []) + + assert recv_binary_frame(client) == {:ok, "BINARY"} + end + + defmodule HandleInfoPingSock do + use NoopSock + def handle_in(_data, state), do: {:push, {:text, :erlang.pid_to_list(self())}, state} + def handle_info(_data, state), do: {:push, {:ping, "PING"}, state} + end + + test "can return a ping frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInfoPingSock) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + Process.send(pid, "OK", []) + + assert recv_ping_frame(client) == {:ok, "PING"} + end + + defmodule HandleInfoPongSock do + use NoopSock + def handle_in(_data, state), do: {:push, {:text, :erlang.pid_to_list(self())}, state} + def handle_info(_data, state), do: {:push, {:pong, "PONG"}, state} + end + + test "can return a pong frame", context do + client = tcp_client(context) + http1_handshake(client, HandleInfoPongSock) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + Process.send(pid, "OK", []) + + assert recv_pong_frame(client) == {:ok, "PONG"} + end + + defmodule HandleInfoCloseSock do + use NoopSock + def handle_in(_data, state), do: {:push, {:text, :erlang.pid_to_list(self())}, state} + def handle_info(_data, state), do: {:stop, :normal, state} + end + + test "can close a connection by returning a stop tuple", context do + client = tcp_client(context) + http1_handshake(client, HandleInfoCloseSock) + + send_text_frame(client, "whoami") + {:ok, pid} = recv_text_frame(client) + pid = pid |> String.to_charlist() |> :erlang.list_to_pid() + Process.send(pid, "OK", []) + + assert recv_connection_close_frame(client) == {:ok, <<1000::16>>} + assert connection_closed_for_reading?(client) + end + end + + describe "terminate" do + setup do + Process.register(self(), __MODULE__) + :ok + end + + def send(msg), do: send(__MODULE__, msg) + + defmodule TerminateSock do + use NoopSock + def handle_in({"normal", opcode: :text}, state), do: {:stop, :normal, state} + def terminate(reason, _state), do: WebSocketSockTest.send(reason) + end + + test "is called with :normal on a normal connection shutdown", context do + client = tcp_client(context) + http1_handshake(client, TerminateSock) + + # Get the sock to tell bandit to shut down + send_text_frame(client, "normal") + + assert_receive :normal + end + + test "is called with :remote on a normal remote shutdown", context do + client = tcp_client(context) + http1_handshake(client, TerminateSock) + + send_connection_close_frame(client, 1000) + + assert_receive :remote + end + + test "is called with {:error, reason} on a protocol error", context do + client = tcp_client(context) + http1_handshake(client, TerminateSock) + + :gen_tcp.close(client) + + assert_receive {:error, :closed} + end + + @tag capture_log: true + test "is called with :timeout on a timeout", context do + client = tcp_client(context) + http1_handshake(client, TerminateSock) + + assert_receive :timeout, 1500 + end + end + + # Simple WebSocket client + + def tcp_client(context) do + {:ok, socket} = :gen_tcp.connect('localhost', context[:port], active: false, mode: :binary) + + socket + end + + def http1_handshake(client, module, params \\ []) do + params = params |> Keyword.put(:sock, module) + + :gen_tcp.send(client, """ + GET /?#{URI.encode_query(params)} HTTP/1.1\r + Host: server.example.com\r + Upgrade: websocket\r + Connection: Upgrade\r + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r + Sec-WebSocket-Version: 13\r + \r + """) + + {:ok, response} = :gen_tcp.recv(client, 234) + + [ + "HTTP/1.1 101 Switching Protocols", + "cache-control: max-age=0, private, must-revalidate", + "connection: Upgrade", + "date: " <> _date, + "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + "server: Cowboy", + "upgrade: websocket", + "", + "" + ] = String.split(response, "\r\n") + end + + def connection_closed_for_reading?(client) do + :gen_tcp.recv(client, 0) == {:error, :closed} + end + + def connection_closed_for_writing?(client) do + :gen_tcp.send(client, <<>>) == {:error, :closed} + end + + def recv_text_frame(client) do + {:ok, 0x8, 0x1, body} = recv_frame(client) + {:ok, body} + end + + def recv_binary_frame(client) do + {:ok, 0x8, 0x2, body} = recv_frame(client) + {:ok, body} + end + + def recv_connection_close_frame(client) do + {:ok, 0x8, 0x8, body} = recv_frame(client) + {:ok, body} + end + + def recv_ping_frame(client) do + {:ok, 0x8, 0x9, body} = recv_frame(client) + {:ok, body} + end + + def recv_pong_frame(client) do + {:ok, 0x8, 0xA, body} = recv_frame(client) + {:ok, body} + end + + defp recv_frame(client) do + {:ok, header} = :gen_tcp.recv(client, 2) + <> = header + + {:ok, data} = + case length do + 0 -> + {:ok, <<>>} + + 126 -> + {:ok, <>} = :gen_tcp.recv(client, 2) + :gen_tcp.recv(client, length) + + 127 -> + {:ok, <>} = :gen_tcp.recv(client, 8) + :gen_tcp.recv(client, length) + + length -> + :gen_tcp.recv(client, length) + end + + {:ok, flags, opcode, data} + end + + def send_continuation_frame(client, data, flags \\ 0x8) do + send_frame(client, flags, 0x0, data) + end + + def send_text_frame(client, data, flags \\ 0x8) do + send_frame(client, flags, 0x1, data) + end + + def send_binary_frame(client, data, flags \\ 0x8) do + send_frame(client, flags, 0x2, data) + end + + def send_connection_close_frame(client, reason) do + send_frame(client, 0x8, 0x8, <>) + end + + def send_ping_frame(client, data) do + send_frame(client, 0x8, 0x9, data) + end + + def send_pong_frame(client, data) do + send_frame(client, 0x8, 0xA, data) + end + + defp send_frame(client, flags, opcode, data) do + mask = :rand.uniform(1_000_000) + masked_data = mask(data, mask) + + mask_flag_and_size = + case byte_size(masked_data) do + size when size <= 125 -> <<1::1, size::7>> + size when size <= 65_535 -> <<1::1, 126::7, size::16>> + size -> <<1::1, 127::7, size::64>> + end + + :gen_tcp.send(client, [<>, mask_flag_and_size, <>, masked_data]) + end + + # Note that masking is an involution, so we don't need a separate unmask function + defp mask(payload, mask, acc \\ <<>>) + + defp mask(payload, mask, acc) when is_integer(mask), do: mask(payload, <>, acc) + + defp mask(<>, <>, acc) do + mask(rest, mask, acc <> <>) + end + + defp mask(<>, <>, acc) do + mask(rest, <>, acc <> <>) + end + + defp mask(<<>>, _mask, acc), do: acc +end +