Skip to content

Commit

Permalink
Implement Plug-mediated WebSocket upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrudel committed Oct 31, 2022
1 parent 9518a3a commit aa6211a
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 129 deletions.
19 changes: 11 additions & 8 deletions lib/phoenix/endpoint.ex
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,8 @@ defmodule Phoenix.Endpoint do
]
end

plug :socket_dispatch

# Compile after the debugger so we properly wrap it.
@before_compile Phoenix.Endpoint
end
Expand Down Expand Up @@ -626,10 +628,10 @@ defmodule Phoenix.Endpoint do

dispatches =
for {path, socket, socket_opts} <- sockets,
{path, type, conn_ast, socket, opts} <- socket_paths(module, path, socket, socket_opts) do
{path, plug, conn_ast, plug_opts} <- socket_paths(module, path, socket, socket_opts) do
quote do
defp do_handler(unquote(path), conn, _opts) do
{unquote(type), unquote(conn_ast), unquote(socket), unquote(Macro.escape(opts))}
defp do_handler(unquote(path), conn) do
unquote(plug).call(unquote(conn_ast), unquote(Macro.escape(plug_opts)))
end
end
end
Expand Down Expand Up @@ -659,9 +661,9 @@ defmodule Phoenix.Endpoint do
def __sockets__, do: unquote(Macro.escape(sockets))

@doc false
def __handler__(%{path_info: path} = conn, opts), do: do_handler(path, conn, opts)
def socket_dispatch(%{path_info: path} = conn, _opts), do: do_handler(path, conn)
unquote(dispatches)
defp do_handler(_path, conn, opts), do: {:plug, conn, __MODULE__, opts}
defp do_handler(_path, conn), do: conn
end
end

Expand All @@ -673,8 +675,9 @@ defmodule Phoenix.Endpoint do
paths =
if websocket do
config = Phoenix.Socket.Transport.load_config(websocket, Phoenix.Transports.WebSocket)
plug_init = {endpoint, socket, config}
{conn_ast, match_path} = socket_path(path, config)
[{match_path, :websocket, conn_ast, socket, config} | paths]
[{match_path, Phoenix.Transports.WebSocket, conn_ast, plug_init} | paths]
else
paths
end
Expand All @@ -684,7 +687,7 @@ defmodule Phoenix.Endpoint do
config = Phoenix.Socket.Transport.load_config(longpoll, Phoenix.Transports.LongPoll)
plug_init = {endpoint, socket, config}
{conn_ast, match_path} = socket_path(path, config)
[{match_path, :plug, conn_ast, Phoenix.Transports.LongPoll, plug_init} | paths]
[{match_path, Phoenix.Transports.LongPoll, conn_ast, plug_init} | paths]
else
paths
end
Expand Down Expand Up @@ -909,7 +912,7 @@ defmodule Phoenix.Endpoint do
"""
defmacro socket(path, module, opts \\ []) do
module = Macro.expand(module, %{__CALLER__ | function: {:__handler__, 2}})
module = Macro.expand(module, %{__CALLER__ | function: {:socket_dispatch, 2}})

quote do
@phoenix_sockets {unquote(path), unquote(module), unquote(opts)}
Expand Down
16 changes: 11 additions & 5 deletions lib/phoenix/endpoint/cowboy2_adapter.ex
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ defmodule Phoenix.Endpoint.Cowboy2Adapter do

# Ranch options are read from the top, so we keep the user opts first.
opts = :proplists.delete(:port, opts) ++ [port: port_to_integer(port), otp_app: otp_app]
child_spec(scheme, endpoint, opts)
child_spec(scheme, endpoint, opts, config[:code_reloader])
end

{refs, child_specs} = Enum.unzip(refs_and_specs)
Expand All @@ -75,15 +75,21 @@ defmodule Phoenix.Endpoint.Cowboy2Adapter do
end
end

defp child_spec(scheme, endpoint, config) do
defp child_spec(scheme, endpoint, config, code_reloader?) do
if scheme == :https do
Application.ensure_all_started(:ssl)
end

dispatches = [{:_, Phoenix.Endpoint.Cowboy2Handler, {endpoint, endpoint.init([])}}]
config = Keyword.put_new(config, :dispatch, [{:_, dispatches}])
ref = Module.concat(endpoint, scheme |> Atom.to_string() |> String.upcase())
spec = Plug.Cowboy.child_spec(ref: ref, scheme: scheme, plug: {endpoint, []}, options: config)

plug =
if code_reloader? do
{Phoenix.Endpoint.SyncCodeReloadPlug, {endpoint, []}}
else
{endpoint, []}
end

spec = Plug.Cowboy.child_spec(ref: ref, scheme: scheme, plug: plug, options: config)
spec = update_in(spec.start, &{__MODULE__, :start_link, [scheme, endpoint, &1]})
{ref, spec}
end
Expand Down
114 changes: 9 additions & 105 deletions lib/phoenix/endpoint/cowboy2_handler.ex
Original file line number Diff line number Diff line change
@@ -1,109 +1,11 @@
defmodule Phoenix.Endpoint.Cowboy2Handler do
@moduledoc false

if Code.ensure_loaded?(:cowboy_websocket) and
function_exported?(:cowboy_websocket, :behaviour_info, 1) do
@behaviour :cowboy_websocket
end

@connection Plug.Cowboy.Conn
@already_sent {:plug_conn, :sent}

# Note we keep the websocket state as [handler | state]
# to avoid conflicts with {endpoint, opts}.
def init(req, {endpoint, opts}) do
init(@connection.conn(req), endpoint, opts, true)
end

defp init(conn, endpoint, opts, retry?) do
try do
case endpoint.__handler__(conn, opts) do
{:websocket, conn, handler, opts} ->
case Phoenix.Transports.WebSocket.connect(conn, endpoint, handler, opts) do
{:ok, %Plug.Conn{adapter: {@connection, req}} = conn, state} ->
cowboy_opts =
opts
|> Enum.flat_map(fn
{:timeout, timeout} -> [idle_timeout: timeout]
{:compress, _} = opt -> [opt]
{:max_frame_size, _} = opt -> [opt]
_other -> []
end)
|> Map.new()

handler_opts = Keyword.take(opts, [:fullsweep_after])
triplet = {handler, handler_opts, state}
{:cowboy_websocket, copy_resp_headers(conn, req), triplet, cowboy_opts}

{:error, %Plug.Conn{adapter: {@connection, req}} = conn} ->
{:ok, copy_resp_headers(conn, req), {handler, opts}}
end

{:plug, conn, handler, opts} ->
%{adapter: {@connection, req}} =
conn
|> handler.call(opts)
|> maybe_send(handler)

{:ok, req, {handler, opts}}
end
catch
kind, reason ->
case __STACKTRACE__ do
# Maybe the handler is not available because the code is being recompiled.
# Sync with the code reloader and retry once.
[{^endpoint, :__handler__, _, _} | _] when reason == :undef and retry? ->
Phoenix.CodeReloader.sync()
init(conn, endpoint, opts, false)

stacktrace ->
exit_on_error(kind, reason, stacktrace, {endpoint, :call, [conn, opts]})
end
after
receive do
@already_sent -> :ok
after
0 -> :ok
end
end
end

defp maybe_send(%Plug.Conn{state: :unset}, _plug), do: raise(Plug.Conn.NotSentError)
defp maybe_send(%Plug.Conn{state: :set} = conn, _plug), do: Plug.Conn.send_resp(conn)
defp maybe_send(%Plug.Conn{} = conn, _plug), do: conn

defp maybe_send(other, plug) do
raise "Cowboy2 adapter expected #{inspect(plug)} to return Plug.Conn but got: " <>
inspect(other)
end

defp exit_on_error(
:error,
%Plug.Conn.WrapperError{kind: kind, reason: reason, stack: stack},
_stack,
call
) do
exit_on_error(kind, reason, stack, call)
end
@behaviour :cowboy_websocket

defp exit_on_error(:error, value, stack, call) do
exception = Exception.normalize(:error, value, stack)
:erlang.raise(:exit, {{exception, stack}, call}, [])
end

defp exit_on_error(:throw, value, stack, call) do
:erlang.raise(:exit, {{{:nocatch, value}, stack}, call}, [])
end

defp exit_on_error(:exit, value, _stack, call) do
:erlang.raise(:exit, {value, call}, [])
end

defp copy_resp_headers(%Plug.Conn{} = conn, req) do
Enum.reduce(conn.resp_headers, req, fn {key, val}, acc ->
:cowboy_req.set_resp_header(key, val, acc)
end)
end
# We never actually call this; it's just here to quell compiler warnings
@impl true
def init(req, state), do: {:cowboy_websocket, req, state}

defp handle_reply(handler, {:ok, state}), do: {:ok, [handler | state]}
defp handle_reply(handler, {:push, data, state}), do: {:reply, data, [handler | state]}
Expand All @@ -125,8 +27,7 @@ defmodule Phoenix.Endpoint.Cowboy2Handler do
handle_reply(handler, reply)
end

## Websocket callbacks

@impl true
def websocket_init({handler, process_flags, state}) do
for {key, value} <- process_flags do
:erlang.process_flag(key, value)
Expand All @@ -136,6 +37,7 @@ defmodule Phoenix.Endpoint.Cowboy2Handler do
{:ok, [handler | state]}
end

@impl true
def websocket_handle({opcode, payload}, [handler | state]) when opcode in [:text, :binary] do
handle_reply(handler, handler.handle_in({payload, opcode: opcode}, state))
end
Expand All @@ -152,11 +54,13 @@ defmodule Phoenix.Endpoint.Cowboy2Handler do
{:ok, handler_state}
end

@impl true
def websocket_info(message, [handler | state]) do
handle_reply(handler, handler.handle_info(message, state))
end

def terminate(_reason, _req, {_handler, _state}) do
@impl true
def terminate(_reason, _req, {_handler, _process_flags, _state}) do
:ok
end

Expand Down
36 changes: 36 additions & 0 deletions lib/phoenix/endpoint/sync_code_reload_plug.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
defmodule Phoenix.Endpoint.SyncCodeReloadPlug do
@moduledoc ~S"""
Wraps an Endpoint, attempting to sync with Phoenix's code reloader if
an exception is raising which indicates that we may be in the middle of a reload.
We detect this by looking at the raised exception and seeing if it indicates
that the endpoint is not defined. This indicates that the code reloader may be
mid way through a compile, and that we should attempt to retry the request
after the compile has completed. This is also why this must be implemented in
a separate module (one that is not recompiled in a typical code reload cycle),
since otherwise it may be the case that the endpoint itself is not defined.
"""

@behaviour Plug

def init({endpoint, opts}), do: {endpoint, endpoint.init(opts)}

def call(conn, {endpoint, opts}), do: do_call(conn, endpoint, opts, true)

defp do_call(conn, endpoint, opts, retry?) do
try do
endpoint.call(conn, opts)
rescue
exception in [UndefinedFunctionError] ->
case exception do
%UndefinedFunctionError{module: ^endpoint} when retry? ->
# Sync with the code reloader and retry once
Phoenix.CodeReloader.sync()
do_call(conn, endpoint, opts, false)

exception ->
reraise(exception, __STACKTRACE__)
end
end
end
end
59 changes: 48 additions & 11 deletions lib/phoenix/transports/websocket.ex
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
defmodule Phoenix.Transports.WebSocket do
@moduledoc false
@behaviour Plug

import Plug.Conn

alias Phoenix.Socket.{V1, V2, Transport}

def default_config() do
Expand All @@ -13,36 +17,69 @@ defmodule Phoenix.Transports.WebSocket do
]
end

def connect(%{method: "GET"} = conn, endpoint, handler, opts) do
def init(opts), do: opts

def call(%{method: "GET"} = conn, {endpoint, handler, opts}) do
conn
|> Plug.Conn.fetch_query_params()
|> fetch_query_params()
|> Transport.code_reload(endpoint, opts)
|> Transport.transport_log(opts[:transport_log])
|> Transport.force_ssl(handler, endpoint, opts)
|> Transport.check_origin(handler, endpoint, opts)
|> Transport.check_subprotocols(opts[:subprotocols])
|> case do
%{halted: true} = conn ->
{:error, conn}
conn

%{params: params} = conn ->
keys = Keyword.get(opts, :connect_info, [])
connect_info = Transport.connect_info(conn, endpoint, keys)
config = %{endpoint: endpoint, transport: :websocket, options: opts, params: params, connect_info: connect_info}

config = %{
endpoint: endpoint,
transport: :websocket,
options: opts,
params: params,
connect_info: connect_info
}

cowboy_opts =
opts
|> Enum.flat_map(fn
{:timeout, timeout} -> [idle_timeout: timeout]
{:compress, _} = opt -> [opt]
{:max_frame_size, _} = opt -> [opt]
_other -> []
end)
|> Map.new()

process_flags =
opts
|> Keyword.take([:fullsweep_after])
|> Map.new()

case handler.connect(config) do
{:ok, state} -> {:ok, conn, state}
:error -> {:error, Plug.Conn.send_resp(conn, 403, "")}
{:ok, state} ->
# Cures a Cowboy race condition where it doesn't see our declared websocket_init/1
_ = Code.ensure_loaded?(Phoenix.Endpoint.Cowboy2Handler)
handler_args = {handler, process_flags, state}
upgrade_args = {Phoenix.Endpoint.Cowboy2Handler, handler_args, cowboy_opts}

conn
|> upgrade_adapter(:websocket, upgrade_args)
|> halt()

:error ->
send_resp(conn, 403, "")

{:error, reason} ->
{m, f, args} = opts[:error_handler]
{:error, apply(m, f, [conn, reason | args])}
apply(m, f, [conn, reason | args])
end
end
end

def connect(conn, _, _, _) do
{:error, Plug.Conn.send_resp(conn, 400, "")}
end
def call(conn, _), do: send_resp(conn, 400, "")

def handle_error(conn, _reason), do: Plug.Conn.send_resp(conn, 403, "")
def handle_error(conn, _reason), do: send_resp(conn, 403, "")
end

0 comments on commit aa6211a

Please sign in to comment.