Skip to content

Commit

Permalink
Allow using socket/connect in another process (#5488)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hermanverschooten authored Jun 16, 2023
1 parent fd6663b commit a2b92c0
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 17 deletions.
66 changes: 49 additions & 17 deletions lib/phoenix/test/channel_test.ex
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 197,7 @@ defmodule Phoenix.ChannelTest do
Use this function when you want to create a blank socket
to pass to functions like `UserSocket.connect/3`.
Otherwise, use `socket/3` if you want to build a socket with
Otherwise, use `socket/4` if you want to build a socket with
existing id and assigns.
## Examples
Expand All @@ -206,7 206,7 @@ defmodule Phoenix.ChannelTest do
"""
defmacro socket(socket_module) do
socket(socket_module, nil, [], __CALLER__)
socket(socket_module, nil, [], [], __CALLER__)
end

@doc """
Expand All @@ -216,19 216,35 @@ defmodule Phoenix.ChannelTest do
socket(MyApp.UserSocket, "user_id", %{some: :assign})
If you need to access the socket in another process than the test process,
you can give the `pid` of the test process in the 4th argument.
## Examples
test "connect in a task" do
pid = self()
task = Task.async(fn ->
socket = socket(MyApp.UserSocket, "user_id", %{some: :assign}, test_process: pid)
broadcast_from!(socket, "default", %{"foo" => "bar"})
assert_push "default", %{"foo" => "bar"}
end)
Task.await(task)
end
"""
defmacro socket(socket_module, socket_id, socket_assigns) do
socket(socket_module, socket_id, socket_assigns, __CALLER__)
defmacro socket(socket_module, socket_id, socket_assigns, options \\ []) do
socket(socket_module, socket_id, socket_assigns, options, __CALLER__)
end

defp socket(module, id, assigns, caller) do
defp socket(module, id, assigns, options, caller) do
if endpoint = Module.get_attribute(caller.module, :endpoint) do
quote do
unquote(__MODULE__).__socket__(
unquote(module),
unquote(id),
unquote(assigns),
unquote(endpoint)
unquote(endpoint),
unquote(options)
)
end
else
Expand All @@ -237,15 253,15 @@ defmodule Phoenix.ChannelTest do
end

@doc false
def __socket__(socket, id, assigns, endpoint) do
def __socket__(socket, id, assigns, endpoint, options) do
%Socket{
assigns: Enum.into(assigns, %{}),
endpoint: endpoint,
handler: socket || first_socket!(endpoint),
id: id,
pubsub_server: endpoint.config(:pubsub_server),
serializer: NoopSerializer,
transport: {__MODULE__, fetch_test_supervisor!()},
transport: {__MODULE__, fetch_test_supervisor!(options)},
transport_pid: self()
}
end
Expand All @@ -257,8 273,8 @@ defmodule Phoenix.ChannelTest do
end
end

defp fetch_test_supervisor!() do
case ExUnit.OnExitHandler.get_supervisor(self()) do
defp fetch_test_supervisor!(options) do
case ExUnit.OnExitHandler.get_supervisor(Keyword.get(options, :test_process, self())) do
{:ok, nil} ->
opts = [strategy: :one_for_one, max_restarts: 1_000_000, max_seconds: 1]
{:ok, sup} = Supervisor.start_link([], opts)
Expand All @@ -276,13 292,13 @@ defmodule Phoenix.ChannelTest do
@doc false
@deprecated "Phoenix.ChannelTest.socket/0 is deprecated, please call socket/1 instead"
defmacro socket() do
socket(nil, nil, [], __CALLER__)
socket(nil, nil, [], [], __CALLER__)
end

@doc false
@deprecated "Phoenix.ChannelTest.socket/2 is deprecated, please call socket/3 instead"
@deprecated "Phoenix.ChannelTest.socket/2 is deprecated, please call socket/4 instead"
defmacro socket(id, assigns) do
socket(nil, id, assigns, __CALLER__)
socket(nil, id, assigns, [], __CALLER__)
end

@doc """
Expand All @@ -291,21 307,37 @@ defmodule Phoenix.ChannelTest do
Useful for testing UserSocket authentication. Returns
the result of the handler's `connect/3` callback.
"""
defmacro connect(handler, params, connect_info \\ quote(do: %{})) do
defmacro connect(handler, params, options \\ quote(do: [])) do
if endpoint = Module.get_attribute(__CALLER__.module, :endpoint) do
quote do
unquote(__MODULE__).__connect__(unquote(endpoint), unquote(handler), unquote(params), unquote(connect_info))
unquote(__MODULE__).__connect__(
unquote(endpoint),
unquote(handler),
unquote(params),
unquote(options)
)
end
else
raise "module attribute @endpoint not set for socket/2"
end
end

@doc false
def __connect__(endpoint, handler, params, connect_info) do
def __connect__(endpoint, handler, params, options) do
{connect_info, options} =
if is_map(options) do
IO.warn(
"Passing \"connect_info\" directly to connect/3 is deprecated, please pass \"connect_info: ...\" as an option instead"
)

{options, []}
else
Keyword.pop(options, :connect_info, %{})
end

map = %{
endpoint: endpoint,
transport: {__MODULE__, fetch_test_supervisor!()},
transport: {__MODULE__, fetch_test_supervisor!(options)},
options: [serializer: [{NoopSerializer, "~> 1.0.0"}]],
params: __stringify__(params),
connect_info: connect_info
Expand Down
52 changes: 52 additions & 0 deletions test/phoenix/test/channel_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 206,24 @@ defmodule Phoenix.Test.ChannelTest do
} = socket(UserSocket, "user:id", %{hello: :world})
end

test "socket/4" do
pid = self()

task =
Task.async(fn ->
assert %Socket{
id: "user:id",
assigns: %{hello: :world},
endpoint: @endpoint,
pubsub_server: Phoenix.Test.ChannelTest.PubSub,
serializer: Phoenix.ChannelTest.NoopSerializer,
handler: UserSocket
} = socket(UserSocket, "user:id", %{hello: :world}, test_process: pid)
end)

Task.await(task)
end

## join

test "join/3 with success" do
Expand Down Expand Up @@ -341,6 359,21 @@ defmodule Phoenix.Test.ChannelTest do
assert_broadcast "broadcast", %{"foo" => "bar"}
end

test "connects and joins topics directly, from another process" do
pid = self()

task =
Task.async(fn ->
{:ok, socket} = connect(UserSocket, %{}, test_process: pid)
socket = subscribe_and_join!(socket, "foo:ok")
push(socket, "broadcast", %{"foo" => "bar"})
assert socket.id == "123"
assert_broadcast "broadcast", %{"foo" => "bar"}
end)

Task.await(task)
end

test "pushes atom parameter keys as strings" do
{:ok, _, socket} = join(socket(UserSocket), Channel, "foo:ok")

Expand Down Expand Up @@ -368,6 401,25 @@ defmodule Phoenix.Test.ChannelTest do
assert_push "default", %{"foo" => "bar"}
end

test "push broadcasts by default, outside of test process" do
pid = self()

task =
Task.async(fn ->
socket =
subscribe_and_join!(
socket(UserSocket, "user_id", %{some: :assign}, test_process: pid),
Channel,
"foo:ok"
)

broadcast_from!(socket, "default", %{"foo" => "bar"})
assert_push "default", %{"foo" => "bar"}
end)

Task.await(task)
end

test "handles broadcasts and stops" do
Process.flag(:trap_exit, true)
{:ok, _, socket} = subscribe_and_join(socket(UserSocket), Channel, "foo:ok")
Expand Down

0 comments on commit a2b92c0

Please sign in to comment.