From 43572262b1858c9dfeace358bbd03e78d549cdba Mon Sep 17 00:00:00 2001 From: Mat Trudel Date: Sat, 15 Oct 2022 17:58:48 -0400 Subject: [PATCH] WebSocket support (tests pending) --- lib/plug/cowboy/conn.ex | 6 +- lib/plug/cowboy/handler.ex | 117 +++++++++++++++++++++++++++++++++++-- 2 files changed, 115 insertions(+), 8 deletions(-) diff --git a/lib/plug/cowboy/conn.ex b/lib/plug/cowboy/conn.ex index b8b8544..4982606 100644 --- a/lib/plug/cowboy/conn.ex +++ b/lib/plug/cowboy/conn.ex @@ -91,9 +91,9 @@ defmodule Plug.Cowboy.Conn do end @impl true - def upgrade(_req, _protocol, _opts) do - {:error, :not_supported} - end + def upgrade(req, :websocket, opts), do: {:ok, req |> Map.put(:upgrade, {:websocket, opts})} + + def upgrade(_req, _upgrade, _opts), do: {:error, :not_supported} @impl true def push(req, path, headers) do diff --git a/lib/plug/cowboy/handler.ex b/lib/plug/cowboy/handler.ex index e7aed97..5fda7df 100644 --- a/lib/plug/cowboy/handler.ex +++ b/lib/plug/cowboy/handler.ex @@ -1,5 +1,11 @@ defmodule Plug.Cowboy.Handler do @moduledoc false + + if Code.ensure_loaded?(:cowboy_websocket) and + function_exported?(:cowboy_websocket, :behaviour_info, 1) do + @behaviour :cowboy_websocket + end + @connection Plug.Cowboy.Conn @already_sent {:plug_conn, :sent} @@ -7,12 +13,30 @@ defmodule Plug.Cowboy.Handler do conn = @connection.conn(req) try do - %{adapter: {@connection, req}} = - conn - |> plug.call(opts) - |> maybe_send(plug) + conn + |> plug.call(opts) + |> maybe_send(plug) + |> case do + %Plug.Conn{adapter: {@connection, %{upgrade: {:websocket, websocket_opts}} = req}} = conn -> + {handler, state, connection_opts} = websocket_opts - {:ok, req, {plug, opts}} + cowboy_opts = + connection_opts + |> Enum.flat_map(fn + {:timeout, timeout} -> [idle_timeout: timeout] + {:compress, _} = opt -> [opt] + {:max_frame_size, _} = opt -> [opt] + _other -> [] + end) + |> Map.new() + + handler_opts = Keyword.take(connection_opts, [:fullsweep_after]) + triplet = {handler, handler_opts, state} + {:cowboy_websocket, copy_resp_headers(conn, req), triplet, cowboy_opts} + + %Plug.Conn{adapter: {@connection, req}} -> + {:ok, req, {plug, opts}} + end catch kind, reason -> exit_on_error(kind, reason, __STACKTRACE__, {plug, :call, [conn, opts]}) @@ -25,6 +49,12 @@ defmodule Plug.Cowboy.Handler do end end + defp copy_resp_headers(%Plug.Conn{} = conn, req) do + Enum.reduce(conn.resp_headers, req, fn {key, val}, acc -> + :cowboy_req.set_resp_header(key, val, acc) + end) + end + defp exit_on_error( :error, %Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack}, @@ -55,4 +85,81 @@ defmodule Plug.Cowboy.Handler do raise "Cowboy2 adapter expected #{inspect(plug)} to return Plug.Conn but got: " <> inspect(other) end + + defp handle_reply(handler, {:ok, state}), do: {:ok, [handler | state]} + defp handle_reply(handler, {:push, data, state}), do: {:reply, data, [handler | state]} + + defp handle_reply(handler, {:reply, _status, data, state}), + do: {:reply, data, [handler | state]} + + defp handle_reply(handler, {:stop, _reason, state}), do: {:stop, [handler | state]} + + defp handle_control_frame(payload_with_opts, handler_state) do + [handler | state] = handler_state + + reply = + if function_exported?(handler, :handle_control, 2) do + handler.handle_control(payload_with_opts, state) + else + {:ok, state} + end + + handle_reply(handler, reply) + end + + ## Websocket callbacks + + def websocket_init({handler, process_flags, state}) do + for {key, value} <- process_flags do + :erlang.process_flag(key, value) + end + + {:ok, state} = handler.init(state) + {:ok, [handler | state]} + end + + def websocket_handle({opcode, payload}, [handler | state]) when opcode in [:text, :binary] do + handle_reply(handler, handler.handle_in({payload, opcode: opcode}, state)) + end + + def websocket_handle({opcode, payload}, handler_state) when opcode in [:ping, :pong] do + handle_control_frame({payload, opcode: opcode}, handler_state) + end + + def websocket_handle(opcode, handler_state) when opcode in [:ping, :pong] do + handle_control_frame({nil, opcode: opcode}, handler_state) + end + + def websocket_handle(_other, handler_state) do + {:ok, handler_state} + end + + def websocket_info(message, [handler | state]) do + handle_reply(handler, handler.handle_info(message, state)) + end + + def terminate(_reason, _req, {_handler, _state}) do + :ok + end + + def terminate({:error, :closed}, _req, [handler | state]) do + handler.terminate(:closed, state) + end + + def terminate({:remote, :closed}, _req, [handler | state]) do + handler.terminate(:closed, state) + end + + def terminate({:remote, code, _}, _req, [handler | state]) + when code in 1000..1003 or code in 1005..1011 or code == 1015 do + handler.terminate(:closed, state) + end + + def terminate(:remote, _req, [handler | state]) do + handler.terminate(:closed, state) + end + + def terminate(reason, _req, [handler | state]) do + handler.terminate(reason, state) + end end