Skip to content

Commit

Permalink
WebSocket support (tests pending)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrudel committed Oct 15, 2022
1 parent 8c69a4f commit 4357226
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 8 deletions.
6 changes: 3 additions & 3 deletions lib/plug/cowboy/conn.ex
Expand Up @@ -91,9 +91,9 @@ defmodule Plug.Cowboy.Conn do
end

@impl true
def upgrade(_req, _protocol, _opts) do
{:error, :not_supported}
end
def upgrade(req, :websocket, opts), do: {:ok, req |> Map.put(:upgrade, {:websocket, opts})}

def upgrade(_req, _upgrade, _opts), do: {:error, :not_supported}

@impl true
def push(req, path, headers) do
Expand Down
117 changes: 112 additions & 5 deletions lib/plug/cowboy/handler.ex
@@ -1,18 +1,42 @@
defmodule Plug.Cowboy.Handler do
@moduledoc false

if Code.ensure_loaded?(:cowboy_websocket) and
function_exported?(:cowboy_websocket, :behaviour_info, 1) do
@behaviour :cowboy_websocket
end

@connection Plug.Cowboy.Conn
@already_sent {:plug_conn, :sent}

def init(req, {plug, opts}) do
conn = @connection.conn(req)

try do
%{adapter: {@connection, req}} =
conn
|> plug.call(opts)
|> maybe_send(plug)
conn
|> plug.call(opts)
|> maybe_send(plug)
|> case do
%Plug.Conn{adapter: {@connection, %{upgrade: {:websocket, websocket_opts}} = req}} = conn ->
{handler, state, connection_opts} = websocket_opts

{:ok, req, {plug, opts}}
cowboy_opts =
connection_opts
|> Enum.flat_map(fn
{:timeout, timeout} -> [idle_timeout: timeout]
{:compress, _} = opt -> [opt]
{:max_frame_size, _} = opt -> [opt]
_other -> []
end)
|> Map.new()

handler_opts = Keyword.take(connection_opts, [:fullsweep_after])
triplet = {handler, handler_opts, state}
{:cowboy_websocket, copy_resp_headers(conn, req), triplet, cowboy_opts}

%Plug.Conn{adapter: {@connection, req}} ->
{:ok, req, {plug, opts}}
end
catch
kind, reason ->
exit_on_error(kind, reason, __STACKTRACE__, {plug, :call, [conn, opts]})
Expand All @@ -25,6 +49,12 @@ defmodule Plug.Cowboy.Handler do
end
end

defp copy_resp_headers(%Plug.Conn{} = conn, req) do
Enum.reduce(conn.resp_headers, req, fn {key, val}, acc ->
:cowboy_req.set_resp_header(key, val, acc)
end)
end

defp exit_on_error(
:error,
%Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack},
Expand Down Expand Up @@ -55,4 +85,81 @@ defmodule Plug.Cowboy.Handler do
raise "Cowboy2 adapter expected #{inspect(plug)} to return Plug.Conn but got: " <>
inspect(other)
end

defp handle_reply(handler, {:ok, state}), do: {:ok, [handler | state]}
defp handle_reply(handler, {:push, data, state}), do: {:reply, data, [handler | state]}

defp handle_reply(handler, {:reply, _status, data, state}),
do: {:reply, data, [handler | state]}

defp handle_reply(handler, {:stop, _reason, state}), do: {:stop, [handler | state]}

defp handle_control_frame(payload_with_opts, handler_state) do
[handler | state] = handler_state

reply =
if function_exported?(handler, :handle_control, 2) do
handler.handle_control(payload_with_opts, state)
else
{:ok, state}
end

handle_reply(handler, reply)
end

## Websocket callbacks

def websocket_init({handler, process_flags, state}) do
for {key, value} <- process_flags do
:erlang.process_flag(key, value)
end

{:ok, state} = handler.init(state)
{:ok, [handler | state]}
end

def websocket_handle({opcode, payload}, [handler | state]) when opcode in [:text, :binary] do
handle_reply(handler, handler.handle_in({payload, opcode: opcode}, state))
end

def websocket_handle({opcode, payload}, handler_state) when opcode in [:ping, :pong] do
handle_control_frame({payload, opcode: opcode}, handler_state)
end

def websocket_handle(opcode, handler_state) when opcode in [:ping, :pong] do
handle_control_frame({nil, opcode: opcode}, handler_state)
end

def websocket_handle(_other, handler_state) do
{:ok, handler_state}
end

def websocket_info(message, [handler | state]) do
handle_reply(handler, handler.handle_info(message, state))
end

def terminate(_reason, _req, {_handler, _state}) do
:ok
end

def terminate({:error, :closed}, _req, [handler | state]) do
handler.terminate(:closed, state)
end

def terminate({:remote, :closed}, _req, [handler | state]) do
handler.terminate(:closed, state)
end

def terminate({:remote, code, _}, _req, [handler | state])
when code in 1000..1003 or code in 1005..1011 or code == 1015 do
handler.terminate(:closed, state)
end

def terminate(:remote, _req, [handler | state]) do
handler.terminate(:closed, state)
end

def terminate(reason, _req, [handler | state]) do
handler.terminate(reason, state)
end
end

0 comments on commit 4357226

Please sign in to comment.