Skip to content

Commit dddeb8c

Browse files
yashksaini-coder, Copilot, and acul71
authored
refactor: add shared pubsub test fixtures and wait_for polling helper (#1298)
* refactor: add shared pubsub test fixtures and wait_for polling helper (#378) * refactor: simplify gossipsub_nodes function signature * Update tests/core/pubsub/conftest.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: address review feedback from acul71 - wait_for(): use inspect.isawaitable(result) instead of inspect.iscoroutinefunction(predicate) so lambdas returning coroutines are properly awaited (prevents false positives and "coroutine was never awaited" warnings) - subscribed_mesh: add TODO(#378) on settle_time sleep - add newsfragments/378.internal.rst for towncrier * test(pubsub): add strict option to connected_gossipsub_nodes Default behaviour is unchanged: each node waits for exactly one expected neighbour after dense_connect, which keeps the fixture fast for the common case. Pass strict=True to wait until every node has observed every other expected peer, for topology-sensitive tests that assert exact peer counts or full fanout. Addresses acul71's minor improvement on PR #1298. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: acul71 <34693171+acul71@users.noreply.github.com>
1 parent 7908810 commit dddeb8c

4 files changed

Lines changed: 230 additions & 77 deletions

File tree

newsfragments/378.internal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add shared pubsub test fixtures (``GossipSubHarness``, ``gossipsub_nodes``, ``connected_gossipsub_nodes``, ``subscribed_mesh``) and reusable polling helpers (``wait_for``, ``wait_for_convergence``) to support the pubsub test suite refactor.

tests/core/pubsub/conftest.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""Shared fixtures and helpers for pubsub tests."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import AsyncIterator
6+
from contextlib import asynccontextmanager
7+
import dataclasses
8+
from typing import Any
9+
10+
import pytest
11+
import trio
12+
13+
from libp2p.abc import IHost
14+
from libp2p.pubsub.gossipsub import GossipSub
15+
from libp2p.pubsub.pubsub import Pubsub
16+
from tests.utils.factories import PubsubFactory
17+
from tests.utils.pubsub.utils import dense_connect
18+
19+
20+
@dataclasses.dataclass(frozen=True, slots=True)
21+
class GossipSubHarness:
22+
"""Typed wrapper around a batch of GossipSub-backed pubsub instances."""
23+
24+
pubsubs: tuple[Pubsub, ...]
25+
26+
@property
27+
def hosts(self) -> tuple[IHost, ...]:
28+
return tuple(ps.host for ps in self.pubsubs)
29+
30+
@property
31+
def routers(self) -> tuple[GossipSub, ...]:
32+
result: list[GossipSub] = []
33+
for ps in self.pubsubs:
34+
r = ps.router
35+
assert isinstance(r, GossipSub), f"Expected GossipSub, got {type(r)}"
36+
result.append(r)
37+
return tuple(result)
38+
39+
def __len__(self) -> int:
40+
return len(self.pubsubs)
41+
42+
43+
@asynccontextmanager
async def gossipsub_nodes(n: int, **kwargs: Any) -> AsyncIterator[GossipSubHarness]:
    """
    Create *n* GossipSub-backed pubsub nodes wrapped in a harness.

    Usage::

        async with gossipsub_nodes(3, heartbeat_interval=0.5) as h:
            h.pubsubs  # tuple[Pubsub, ...]
            h.hosts  # tuple[IHost, ...]
            h.routers  # tuple[GossipSub, ...]
    """
    # All keyword arguments are forwarded verbatim to the factory.
    batch_cm = PubsubFactory.create_batch_with_gossipsub(n, **kwargs)
    async with batch_cm as batch:
        yield GossipSubHarness(pubsubs=batch)
57+
58+
59+
@asynccontextmanager
async def connected_gossipsub_nodes(
    n: int, *, strict: bool = False, **kwargs: Any
) -> AsyncIterator[GossipSubHarness]:
    """
    Create *n* GossipSub nodes with dense connectivity.

    By default this waits only until each node has observed one expected
    neighbour (fast path). Pass ``strict=True`` to wait until every node
    has observed every other expected peer — useful for topology-sensitive
    tests that assert exact peer counts or full fanout behaviour.
    """
    # Consume our own option before forwarding the rest to the factory.
    peer_wait_timeout = kwargs.pop("peer_wait_timeout", 5.0)
    async with gossipsub_nodes(n, **kwargs) as harness:
        await dense_connect(harness.hosts)
        if n > 1:
            with trio.fail_after(peer_wait_timeout):
                for index, pubsub in enumerate(harness.pubsubs):
                    if strict:
                        # Every other node must be visible as a peer.
                        expected_hosts = (
                            host
                            for other_index, host in enumerate(harness.hosts)
                            if other_index != index
                        )
                    else:
                        # Fast path: seeing the next neighbour is enough.
                        expected_hosts = iter((harness.hosts[(index + 1) % n],))
                    for host in expected_hosts:
                        await pubsub.wait_for_peer(host.get_id())
        yield harness
87+
88+
89+
@asynccontextmanager
async def subscribed_mesh(
    topic: str, n: int, *, settle_time: float = 1.0, **kwargs: Any
) -> AsyncIterator[GossipSubHarness]:
    """
    Create *n* connected GossipSub nodes all subscribed to *topic*.

    Waits *settle_time* seconds for mesh formation before yielding.
    """
    async with connected_gossipsub_nodes(n, **kwargs) as harness:
        # Subscribe every node to the shared topic before letting the mesh settle.
        for node in harness.pubsubs:
            await node.subscribe(topic)
        # TODO(#378): replace fixed sleep with predicate-based mesh-ready polling
        await trio.sleep(settle_time)
        yield harness
104+
105+
106+
@pytest.fixture
async def connected_gossipsub_pair() -> AsyncIterator[GossipSubHarness]:
    """Fixture: two connected GossipSub nodes with default config."""
    async with connected_gossipsub_nodes(2) as pair:
        yield pair

tests/core/pubsub/test_dummyaccount_demo.py

Lines changed: 3 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
from collections.abc import (
2-
Callable,
3-
)
4-
import logging
5-
61
import pytest
72
import trio
83

@@ -12,69 +7,9 @@
127
from tests.utils.pubsub.dummy_account_node import (
138
DummyAccountNode,
149
)
15-
16-
logger = logging.getLogger(__name__)
17-
18-
19-
async def wait_for_convergence(
20-
nodes: tuple[DummyAccountNode, ...],
21-
check: Callable[[DummyAccountNode], bool],
22-
timeout: float = 10.0,
23-
poll_interval: float = 0.02,
24-
log_success: bool = False,
25-
raise_last_exception_on_timeout: bool = True,
26-
) -> None:
27-
"""
28-
Wait until all nodes satisfy the check condition.
29-
30-
Returns as soon as convergence is reached, otherwise raises TimeoutError.
31-
Convergence already guarantees all nodes satisfy the check, so callers need
32-
not run a second assertion pass after this returns.
33-
"""
34-
start_time = trio.current_time()
35-
36-
last_exception: Exception | None = None
37-
last_exception_node: int | None = None
38-
39-
while True:
40-
failed_indices: list[int] = []
41-
for i, node in enumerate(nodes):
42-
try:
43-
ok = check(node)
44-
except Exception as exc:
45-
ok = False
46-
last_exception = exc
47-
last_exception_node = i
48-
if not ok:
49-
failed_indices.append(i)
50-
51-
if not failed_indices:
52-
elapsed = trio.current_time() - start_time
53-
if log_success:
54-
logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
55-
return
56-
57-
elapsed = trio.current_time() - start_time
58-
if elapsed > timeout:
59-
if raise_last_exception_on_timeout and last_exception is not None:
60-
# Preserve the underlying assertion/exception signal (and its message)
61-
# instead of hiding it behind a generic timeout.
62-
node_hint = (
63-
f" (node index {last_exception_node})"
64-
if last_exception_node is not None
65-
else ""
66-
)
67-
raise AssertionError(
68-
f"Convergence failed{node_hint}: {last_exception}"
69-
) from last_exception
70-
71-
raise TimeoutError(
72-
f"Convergence timeout after {elapsed:.2f}s. "
73-
f"Failed nodes: {failed_indices}. "
74-
f"(Hint: run with -s and pass log_success=True for timing logs)"
75-
)
76-
77-
await trio.sleep(poll_interval)
10+
from tests.utils.pubsub.wait import (
11+
wait_for_convergence,
12+
)
7813

7914

8015
async def perform_test(num_nodes, adjacency_map, action_func, assertion_func):
@@ -116,7 +51,6 @@ def _check_final(node: DummyAccountNode) -> bool:
11651
# Success, terminate pending tasks.
11752

11853

119-
@pytest.mark.trio
12054
async def test_simple_two_nodes():
12155
num_nodes = 2
12256
adj_map = {0: [1]}
@@ -130,7 +64,6 @@ def assertion_func(dummy_node):
13064
await perform_test(num_nodes, adj_map, action_func, assertion_func)
13165

13266

133-
@pytest.mark.trio
13467
async def test_simple_three_nodes_line_topography():
13568
num_nodes = 3
13669
adj_map = {0: [1], 1: [2]}
@@ -144,7 +77,6 @@ def assertion_func(dummy_node):
14477
await perform_test(num_nodes, adj_map, action_func, assertion_func)
14578

14679

147-
@pytest.mark.trio
14880
async def test_simple_three_nodes_triangle_topography():
14981
num_nodes = 3
15082
adj_map = {0: [1, 2], 1: [2]}
@@ -158,7 +90,6 @@ def assertion_func(dummy_node):
15890
await perform_test(num_nodes, adj_map, action_func, assertion_func)
15991

16092

161-
@pytest.mark.trio
16293
async def test_simple_seven_nodes_tree_topography():
16394
num_nodes = 7
16495
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@@ -172,7 +103,6 @@ def assertion_func(dummy_node):
172103
await perform_test(num_nodes, adj_map, action_func, assertion_func)
173104

174105

175-
@pytest.mark.trio
176106
async def test_set_then_send_from_root_seven_nodes_tree_topography():
177107
num_nodes = 7
178108
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@@ -197,7 +127,6 @@ def assertion_func(dummy_node):
197127
await perform_test(num_nodes, adj_map, action_func, assertion_func)
198128

199129

200-
@pytest.mark.trio
201130
async def test_set_then_send_from_different_leafs_seven_nodes_tree_topography():
202131
num_nodes = 7
203132
adj_map = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
@@ -216,7 +145,6 @@ def assertion_func(dummy_node):
216145
await perform_test(num_nodes, adj_map, action_func, assertion_func)
217146

218147

219-
@pytest.mark.trio
220148
async def test_simple_five_nodes_ring_topography():
221149
num_nodes = 5
222150
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@@ -230,7 +158,6 @@ def assertion_func(dummy_node):
230158
await perform_test(num_nodes, adj_map, action_func, assertion_func)
231159

232160

233-
@pytest.mark.trio
234161
async def test_set_then_send_from_diff_nodes_five_nodes_ring_topography():
235162
num_nodes = 5
236163
adj_map = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]}
@@ -252,7 +179,6 @@ def assertion_func(dummy_node):
252179
await perform_test(num_nodes, adj_map, action_func, assertion_func)
253180

254181

255-
@pytest.mark.trio
256182
@pytest.mark.slow
257183
async def test_set_then_send_from_five_diff_nodes_five_nodes_ring_topography():
258184
num_nodes = 5

tests/utils/pubsub/wait.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""Polling helpers for pubsub test synchronization."""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Callable
6+
import inspect
7+
import logging
8+
from typing import TYPE_CHECKING
9+
10+
import trio
11+
12+
if TYPE_CHECKING:
13+
from tests.utils.pubsub.dummy_account_node import DummyAccountNode
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
async def wait_for(
    predicate: Callable[[], object],
    *,
    timeout: float = 10.0,
    poll_interval: float = 0.02,
    fail_msg: str = "",
) -> None:
    """
    Poll until *predicate()* returns a truthy value, or raise ``TimeoutError``.

    Supports sync predicates, async predicates, and callables that return
    awaitables (e.g. ``lambda: some_async_fn()``). If the predicate raises
    an exception it is treated as falsy; on timeout the last such exception
    is chained to the ``TimeoutError``.
    """
    started = trio.current_time()
    last_exc: Exception | None = None

    while True:
        try:
            outcome = predicate()
            # isawaitable (not iscoroutinefunction) so lambdas returning
            # coroutines are properly awaited.
            if inspect.isawaitable(outcome):
                outcome = await outcome
            if outcome:
                return
        except Exception as exc:
            # Treat a raising predicate as "not yet"; remember it for chaining.
            last_exc = exc

        elapsed = trio.current_time() - started
        if elapsed > timeout:
            msg = fail_msg or f"wait_for timed out after {elapsed:.2f}s"
            failure = TimeoutError(msg)
            if last_exc is None:
                raise failure
            raise failure from last_exc

        await trio.sleep(poll_interval)
55+
56+
57+
async def wait_for_convergence(
    nodes: tuple[DummyAccountNode, ...],
    check: Callable[[DummyAccountNode], bool],
    timeout: float = 10.0,
    poll_interval: float = 0.02,
    log_success: bool = False,
    raise_last_exception_on_timeout: bool = True,
) -> None:
    """
    Wait until all *nodes* satisfy *check*.

    Returns as soon as convergence is reached, otherwise raises
    ``TimeoutError`` (or ``AssertionError`` when
    *raise_last_exception_on_timeout* is ``True`` and a node raised).

    Preserves the API of the original inline helper from
    ``test_dummyaccount_demo.py``.
    """
    started = trio.current_time()

    last_exception: Exception | None = None
    last_exception_node: int | None = None

    while True:
        failed_indices: list[int] = []
        for i, node in enumerate(nodes):
            try:
                passed = check(node)
            except Exception as exc:
                # A raising check counts as a failure; keep the exception so a
                # timeout can surface the real signal instead of a generic one.
                passed = False
                last_exception = exc
                last_exception_node = i
            if not passed:
                failed_indices.append(i)

        elapsed = trio.current_time() - started

        if not failed_indices:
            if log_success:
                logger.debug("Converged in %.3fs with %d nodes", elapsed, len(nodes))
            return

        if elapsed > timeout:
            if raise_last_exception_on_timeout and last_exception is not None:
                node_hint = (
                    f" (node index {last_exception_node})"
                    if last_exception_node is not None
                    else ""
                )
                raise AssertionError(
                    f"Convergence failed{node_hint}: {last_exception}"
                ) from last_exception

            raise TimeoutError(
                f"Convergence timeout after {elapsed:.2f}s. "
                f"Failed nodes: {failed_indices}. "
                f"(Hint: run with -s and pass log_success=True for timing logs)"
            )

        await trio.sleep(poll_interval)

0 commit comments

Comments
 (0)