Commit 75d0dfc

feat: use NIXL native P2P metadata exchange instead of centralized blob storage (#177)
Signed-off-by: Nicolas 'Pixel' Noble <nicolas@nobis-crew.org>
1 parent 9b19667 commit 75d0dfc

19 files changed, +806 -86 lines

docs/ARCHITECTURE.md

Lines changed: 19 additions & 5 deletions
@@ -251,7 +251,15 @@ Key message types: `ModelProvider` (HuggingFace), `ModelStatus` (Downloading, Do
 | `GetMetadata` | `GetMetadataRequest` | `GetMetadataResponse` | Fetch full tensor metadata for one specific worker (MB-scale, on demand) |
 | `UpdateStatus` | `UpdateStatusRequest` | `UpdateStatusResponse` | Update per-worker lifecycle status (Initializing/Ready/Stale) |
 
-Key message types: `SourceIdentity` (all fields affecting tensor layout compatibility), `WorkerMetadata` (rank, oneof backend_metadata, tensors, status), `TensorDescriptor` (name, addr, size, device_id, dtype), `SourceInstanceRef` (lightweight worker reference for listing).
+Key message types: `SourceIdentity` (all fields affecting tensor layout compatibility), `WorkerMetadata` (rank, oneof backend_metadata, tensors, status, P2P endpoint fields), `TensorDescriptor` (name, addr, size, device_id, dtype), `SourceInstanceRef` (lightweight worker reference for listing).
+
+### p2p.proto - WorkerService (P2P, opt-in)
+
+| RPC | Request | Response | Purpose |
+|-----|---------|----------|---------|
+| `GetTensorManifest` | `GetTensorManifestRequest` | `GetTensorManifestResponse` | Fetch tensor descriptors directly from a source worker |
+
+Per-worker gRPC service started when `MX_P2P_METADATA=1`. Targets call this instead of fetching tensor descriptors from the central server. Validates `mx_source_id` to catch stale discovery.
 
 See [`metadata.md`](metadata.md) for the full metadata architecture including storage schemas and coordination protocol.
 
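The stale-discovery check in `GetTensorManifest` can be sketched in plain Python (the dataclass and function names below are hypothetical illustrations, not the actual `worker_server.py` implementation):

```python
from dataclasses import dataclass, field


@dataclass
class TensorDescriptor:
    name: str
    addr: int
    size: int
    device_id: int
    dtype: str


@dataclass
class WorkerState:
    mx_source_id: str
    tensors: list[TensorDescriptor] = field(default_factory=list)


def get_tensor_manifest(state: WorkerState, requested_source_id: str) -> list[TensorDescriptor]:
    """Serve tensor descriptors, rejecting a stale mx_source_id.

    A mismatched id means the target discovered an earlier incarnation of
    this worker via the central server and must re-run discovery.
    """
    if requested_source_id != state.mx_source_id:
        raise ValueError(
            f"stale mx_source_id {requested_source_id!r}; "
            f"current is {state.mx_source_id!r}"
        )
    return state.tensors
```

In the real service this check would sit inside the gRPC handler and map the mismatch to a gRPC error status rather than a `ValueError`.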
@@ -430,6 +438,7 @@ Loading precedence: CLI args > environment variables > config file > defaults.
 | `gds_transfer.py` | GPUDirect Storage availability check and transfer utilities |
 | `gds_loader.py` | `MxGdsLoader` - GDS-based model loader (direct file-to-GPU) |
 | `vllm_loader.py` | `MxModelLoader` - auto-detecting model loader (RDMA -> GDS -> disk) |
+| `worker_server.py` | `WorkerGrpcServer` - per-worker gRPC server for P2P tensor manifest exchange |
 | `vllm_worker.py` | `ModelExpressWorker` - custom vLLM worker class (use `--worker-cls=modelexpress.vllm_worker.ModelExpressWorker`) |
 | `types.py` | `TensorDescriptor`, `WorkerMetadata`, `GetMetadataResponse` dataclasses |
 | `p2p_pb2.py` / `p2p_pb2_grpc.py` | Generated protobuf/gRPC stubs |
@@ -452,10 +461,11 @@ Manages a NIXL agent and RDMA transfers for a single GPU worker:
 
 | Method | Purpose |
 |--------|---------|
-| `__init__(agent_name, device_id)` | Create NIXL agent with UCX backend |
+| `__init__(agent_name, device_id, listen_port)` | Create NIXL agent with UCX backend; `listen_port` enables P2P listen thread |
 | `register_tensors(tensors)` | Register GPU tensors for RDMA, return serialized metadata |
 | `get_registered_descriptors()` | Return region descriptors (`MX_CONTIGUOUS_REG=1`) or tensor descriptors |
-| `receive_from_source(source_metadata, source_tensors, ...)` | Execute RDMA read transfer with optional coalescing |
+| `fetch_remote_and_wait(agent_name, ip, port)` | P2P: fetch remote NIXL metadata via listen thread (polls until loaded) |
+| `receive_from_source(source_metadata, source_tensors, ..., remote_agent_name)` | Execute RDMA read transfer; `remote_agent_name` skips `add_remote_agent` (P2P) |
 | `shutdown()` | Clean up NIXL agent and resources |
 
### vLLM Loader
@@ -558,10 +568,10 @@ graph TD
 ### Flow
 
 1. **Source loads**: Loads weights from disk (or GDS), runs `process_weights_after_loading()`
-2. **Source publishes**: Registers tensors with NIXL, calls `PublishMetadata(identity, worker, worker_id)` -> gets `mx_source_id` (status=INITIALIZING)
+2. **Source publishes**: Registers tensors with NIXL, calls `PublishMetadata(identity, worker, worker_id)` -> gets `mx_source_id` (status=INITIALIZING). In P2P mode (`MX_P2P_METADATA=1`), publishes only lightweight endpoint pointers and starts a `WorkerGrpcServer` for tensor manifest serving.
 3. **Heartbeat starts**: `HeartbeatThread` sends `UpdateStatus(READY)` every 30s, refreshing `updated_at`
 4. **Target discovers**: Calls `ListSources(identity, status=READY)`, filters by `worker_rank`
-5. **Target fetches on demand**: Calls `GetMetadata(mx_source_id, worker_id)` for the chosen candidate
+5. **Target fetches on demand**: Calls `GetMetadata(mx_source_id, worker_id)` for the chosen candidate. Auto-detects P2P mode if `worker_grpc_endpoint` is populated - fetches tensors from the source worker's `WorkerService` and NIXL metadata via the listen thread instead of from the central server.
 6. **Target transfers**: Executes RDMA reads from source; on `SourceTransferError` tries next candidate (max 3)
 7. **Target becomes source**: After receiving weights, publishes own metadata and starts its own heartbeat
 8. **Stale detection**: Server-side reaper marks workers STALE if `updated_at` > 90s old; GC deletes after 1 hour
@@ -579,6 +589,10 @@ See [`metadata.md`](metadata.md) for the full storage schema and debugging guide
 | `MX_SERVER_ADDRESS` | `localhost:8001` | Backward-compat alias for `MODEL_EXPRESS_URL` |
 | `MX_METADATA_BACKEND` | (required) | Metadata backend: `redis` or `kubernetes` |
 | `MX_CONTIGUOUS_REG` | `0` | Enable contiguous region registration (experimental) |
+| `MX_P2P_METADATA` | `0` | Enable P2P metadata exchange on source workers |
+| `MX_METADATA_PORT` | `0` | NIXL listen thread port for P2P metadata exchange |
+| `MX_WORKER_GRPC_PORT` | `0` | Worker gRPC port for P2P tensor manifest serving |
+| `MX_WORKER_HOST` | (auto-detect) | Override worker IP/hostname for P2P endpoints |
 | `MX_HEARTBEAT_INTERVAL_SECS` | `30` | Client heartbeat frequency |
 | `MX_HEARTBEAT_TIMEOUT_SECS` | `90` | Server reaper staleness threshold |
 | `MX_REAPER_SCAN_INTERVAL_SECS` | `30` | Server reaper scan frequency |

docs/DEPLOYMENT.md

Lines changed: 16 additions & 0 deletions
@@ -230,6 +230,10 @@ ModelExpress supports GPU-to-GPU model weight transfers between vLLM instances u
 | `MX_SERVER_ADDRESS` | `localhost:8001` | Backward-compat alias for `MODEL_EXPRESS_URL` |
 | `MX_REGISTER_LOADERS` | `1` | Auto-register the mx loader with vLLM |
 | `MX_CONTIGUOUS_REG` | `0` | Contiguous region registration (experimental) |
+| `MX_P2P_METADATA` | `0` | Enable P2P metadata exchange (source workers only) |
+| `MX_METADATA_PORT` | `0` | NIXL listen thread port for P2P metadata exchange |
+| `MX_WORKER_GRPC_PORT` | `0` | Worker gRPC port for P2P tensor manifest serving |
+| `MX_WORKER_HOST` | (auto-detect) | Override worker IP/hostname for P2P endpoints |
 | `MX_STATUS_TTL_SECS` | `3600` | TTL for Redis metadata keys (seconds) |
 | `REDIS_URL` | `redis://localhost:6379` | Redis connection URL (Redis backend only) |
 | `MX_METADATA_NAMESPACE` | `default` | K8s namespace for CRD backend |
@@ -238,6 +242,18 @@ ModelExpress supports GPU-to-GPU model weight transfers between vLLM instances u
 
 Each GPU worker publishes independently using its global rank (`torch.distributed.get_rank()`). No inter-worker coordination or barriers required.
 
+### P2P Metadata Exchange (Opt-In)
+
+By default, source workers publish full tensor metadata (NIXL blobs + tensor descriptors) to the central server. With `MX_P2P_METADATA=1`, source workers instead publish lightweight endpoint pointers and exchange metadata directly with targets:
+
+- **NIXL agent blobs** exchanged via NIXL's native listen thread (`MX_METADATA_PORT`)
+- **Tensor descriptors** served by a per-worker gRPC `WorkerService` (`MX_WORKER_GRPC_PORT`)
+- **Central server** stores only endpoint addresses, not MB-scale metadata
+
+Targets auto-detect which mode a source is using based on whether `worker_grpc_endpoint` is populated in the metadata. No configuration needed on the target side.
+
+Set `MX_METADATA_PORT` and `MX_WORKER_GRPC_PORT` to fixed ports when running in K8s (port 0 picks an ephemeral port). Set `MX_WORKER_HOST` if the pod IP auto-detection doesn't produce a routable address.
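How a source worker might assemble its advertised endpoints from these variables, as a sketch (the helper name is hypothetical; the actual resolution logic lives in the client):

```python
import os
import socket


def resolve_p2p_endpoints(env=os.environ):
    """Build the (NIXL metadata, worker gRPC) endpoint strings to advertise.

    Port 0 asks the OS for an ephemeral port, which is only usable when
    targets can learn the bound port afterward; in K8s, set fixed ports.
    MX_WORKER_HOST overrides the auto-detected host when the default
    address is not routable from other pods.
    """
    host = env.get("MX_WORKER_HOST") or socket.gethostbyname(socket.gethostname())
    meta_port = int(env.get("MX_METADATA_PORT", "0"))
    grpc_port = int(env.get("MX_WORKER_GRPC_PORT", "0"))
    return f"{host}:{meta_port}", f"{host}:{grpc_port}"
```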
+
 ### UCX/NIXL Tuning
 
 | Variable | Recommended | Description |
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Single-node vLLM deployment with P2P metadata exchange enabled.
# Same as vllm-single-node.yaml but with MX_P2P_METADATA=1 on source workers.
#
# With P2P enabled, source workers exchange NIXL metadata and tensor manifests
# directly with targets instead of routing through the central server. The
# central server stores only lightweight endpoint pointers.
#
# Targets auto-detect P2P sources and need no special configuration.
#
# Prerequisites:
# - ModelExpress server deployed (see ../../server/)
# - PVC with model weights pre-downloaded
# - kubectl create secret generic hf-token-secret --from-literal=HF_TOKEN=<token>
apiVersion: v1
kind: Service
metadata:
  name: mx-vllm-p2p
  labels:
    app: mx-vllm-p2p
spec:
  type: ClusterIP
  ports:
    - port: 8000
      targetPort: 8000
      name: http
  selector:
    app: mx-vllm-p2p
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: mx-vllm-p2p
  labels:
    app: mx-vllm-p2p
spec:
  replicas: 1
  selector:
    matchLabels:
      app: mx-vllm-p2p
  template:
    metadata:
      labels:
        app: mx-vllm-p2p
    spec:
      serviceAccountName: modelexpress
      containers:
        - name: vllm
          image: nvcr.io/nvidian/dynamo-dev/modelexpress-client:latest
          imagePullPolicy: IfNotPresent
          securityContext:
            capabilities:
              add:
                - IPC_LOCK
          env:
            - name: VLLM_RPC_TIMEOUT
              value: "7200000"
            - name: HF_HUB_CACHE
              value: "/models"
            - name: MODEL_NAME
              value: "deepseek-ai/DeepSeek-V3"
            - name: VLLM_PLUGINS
              value: "modelexpress"
            - name: MX_SERVER_ADDRESS
              value: "modelexpress-server:8001"
            - name: MX_CONTIGUOUS_REG
              value: "0"
            # P2P metadata exchange: source workers serve NIXL metadata and
            # tensor manifests directly to targets instead of via the server.
            - name: MX_P2P_METADATA
              value: "1"
            # Fixed ports for NIXL listen thread and worker gRPC server.
            # Use fixed ports in K8s so they can be reached across pods.
            - name: MX_METADATA_PORT
              value: "5555"
            - name: MX_WORKER_GRPC_PORT
              value: "6555"
            - name: NIXL_LOG_LEVEL
              value: "INFO"
            - name: UCX_LOG_LEVEL
              value: "INFO"
            - name: UCX_TLS
              value: "rc_x,rc,dc_x,dc,cuda_copy"
            - name: UCX_RNDV_SCHEME
              value: "get_zcopy"
            - name: UCX_RNDV_THRESH
              value: "0"
            - name: POD_IP
              valueFrom:
                fieldRef:
                  fieldPath: status.podIP
            - name: NODE_NAME
              valueFrom:
                fieldRef:
                  fieldPath: spec.nodeName
            - name: POD_NAMESPACE
              valueFrom:
                fieldRef:
                  fieldPath: metadata.namespace
            - name: HF_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-token-secret
                  key: HF_TOKEN
          args:
            - --model
            - $(MODEL_NAME)
            - --load-format
            - mx
            - --tensor-parallel-size
            - "8"
            - --enable-expert-parallel
          resources:
            limits:
              nvidia.com/gpu: "8"
              rdma/ib: "8"
            requests:
              nvidia.com/gpu: "8"
              rdma/ib: "8"
              memory: "200Gi"
              cpu: "16"
          volumeMounts:
            - name: shm
              mountPath: /dev/shm
            - name: model-cache-block
              mountPath: /models
      volumes:
        - name: shm
          emptyDir:
            medium: Memory
            sizeLimit: 64Gi
        - name: model-cache-block
          persistentVolumeClaim:
            claimName: model-cache-block
      imagePullSecrets:
        - name: nvcr-imagepullsecret

examples/p2p_transfer_k8s/server/kubernetes_backend/crd-modelmetadata.yaml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ spec:
5151
- none
5252
nixlMetadata:
5353
type: string
54-
format: byte
5554
description: Base64-encoded NIXL agent metadata blob
5655
transferEngineSessionId:
5756
type: string
@@ -62,6 +61,15 @@ spec:
6261
tensorConfigMap:
6362
type: string
6463
description: Name of ConfigMap containing tensor descriptors
64+
metadataEndpoint:
65+
type: string
66+
description: P2P NIXL listen thread endpoint (host:port)
67+
agentName:
68+
type: string
69+
description: P2P NIXL agent name for remote identification
70+
workerGrpcEndpoint:
71+
type: string
72+
description: P2P worker gRPC endpoint for tensor manifest (host:port)
6573
status:
6674
type: string
6775
description: Worker lifecycle status

modelexpress_client/python/modelexpress/heartbeat.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def __init__(
         worker_id: str,
         worker_rank: int,
         nixl_manager: NixlTransferManager,
-
+
     ):
         self._mx_client = mx_client
         self._mx_source_id = mx_source_id

modelexpress_client/python/modelexpress/nixl_transfer.py

Lines changed: 69 additions & 6 deletions
@@ -53,16 +53,22 @@ class NixlTransferManager:
         device_id: GPU device ID for this worker
     """
 
-    def __init__(self, agent_name: str, device_id: int):
+    def __init__(self, agent_name: str, device_id: int, listen_port: int | None = None):
         self._agent_name = agent_name
         self._device_id = device_id
+        self._listen_port = listen_port
 
         self._agent: Any = None
         self._metadata: bytes = b""
         self._tensor_descriptors: list[TensorDescriptor] = []
         self._tensors: dict[str, torch.Tensor] = {}
         self._registered_regions: list[tuple[int, int]] | None = None
 
+    @property
+    def agent_name(self) -> str:
+        """Get NIXL agent name."""
+        return self._agent_name
+
     @property
     def nixl_metadata(self) -> bytes:
         """Get NIXL metadata for this agent."""
@@ -83,7 +89,19 @@ def initialize(self) -> None:
 
         torch.cuda.set_device(self._device_id)
 
-        config = nixl_agent_config(backends=["UCX"]) if nixl_agent_config else None
+        if self._listen_port is not None and nixl_agent_config:
+            config = nixl_agent_config(
+                backends=["UCX"],
+                enable_listen_thread=True,
+                listen_port=self._listen_port,
+            )
+            logger.info(
+                f"NIXL listen thread enabled on port {self._listen_port}"
+            )
+        elif nixl_agent_config:
+            config = nixl_agent_config(backends=["UCX"])
+        else:
+            config = None
         self._agent = NixlAgent(self._agent_name, config)
         logger.info(f"NIXL agent '{self._agent_name}' created on device {self._device_id}")
 
@@ -227,21 +245,59 @@ def _find_contiguous_regions(
 
         return regions
 
+    def fetch_remote_and_wait(
+        self,
+        remote_agent_name: str,
+        ip: str,
+        port: int,
+        timeout_seconds: float = 120.0,
+    ) -> None:
+        """Fetch remote NIXL agent metadata via the P2P listen thread.
+
+        Initiates an async fetch and polls until the remote agent's metadata
+        is loaded locally. Used in P2P mode instead of add_remote_agent().
+        """
+        if self._agent is None:
+            raise RuntimeError("NIXL agent not initialized")
+
+        logger.info(
+            f"Fetching remote metadata from {remote_agent_name} at {ip}:{port}"
+        )
+        self._agent.fetch_remote_metadata(remote_agent_name, ip, port)
+
+        start = time.perf_counter()
+        while True:
+            if time.perf_counter() - start >= timeout_seconds:
+                raise TimeoutError(
+                    f"Timed out waiting for remote metadata from "
+                    f"{remote_agent_name} at {ip}:{port}"
+                )
+            if self._agent.check_remote_metadata(remote_agent_name):
+                logger.info(
+                    f"Remote metadata loaded for {remote_agent_name} "
+                    f"({time.perf_counter() - start:.2f}s)"
+                )
+                return
+            time.sleep(0.01)
+
     def receive_from_source(
         self,
         source_metadata: bytes,
         source_tensors: list[TensorDescriptor],
         timeout_seconds: float | None = None,
         coalesce_transfers: bool = True,
+        remote_agent_name: str | None = None,
     ) -> tuple[int, int, float]:
         """
         Receive weights from a remote source via NIXL RDMA.
 
         Args:
-            source_metadata: NIXL metadata from the source agent
+            source_metadata: NIXL metadata from the source agent (unused if remote_agent_name set)
             source_tensors: Tensor descriptors from the source
             timeout_seconds: Maximum time to wait for transfer (None for no timeout)
            coalesce_transfers: If True, coalesce contiguous memory regions (optimization)
+            remote_agent_name: If set, use this pre-loaded agent (P2P mode) instead of
+                calling add_remote_agent with source_metadata (centralized mode)
 
         Returns:
             Tuple of (total_bytes, total_tensors, duration)
@@ -252,9 +308,16 @@ def receive_from_source(
         start_time = time.perf_counter()
         torch.cuda.set_device(self._device_id)
 
-        # Add remote agent
-        remote_agent_name = self._agent.add_remote_agent(source_metadata)
-        logger.info(f"Added remote agent {remote_agent_name}")
+        if remote_agent_name is None:
+            add_start = time.perf_counter()
+            remote_agent_name = self._agent.add_remote_agent(source_metadata)
+            add_time = time.perf_counter() - add_start
+            logger.info(
+                f"[TIMING] add_remote_agent: {add_time:.3f}s "
+                f"(agent={remote_agent_name}, blob={len(source_metadata)} bytes)"
+            )
+        else:
+            logger.info(f"Using pre-loaded remote agent {remote_agent_name}")
 
         # Check if source is sending region descriptors (MX_CONTIGUOUS_REG=1 on source)
         is_region_transfer = (
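The poll-until-loaded loop in `fetch_remote_and_wait` is an instance of a general poll-with-timeout pattern; a standalone sketch (not part of the module, with injectable clock/sleep so the behavior is testable without real waiting):

```python
import time


def poll_until(predicate, timeout_seconds=120.0, interval=0.01,
               clock=time.perf_counter, sleep=time.sleep):
    """Poll predicate() until it returns True or timeout_seconds elapses.

    Returns the elapsed time on success; raises TimeoutError otherwise.
    """
    start = clock()
    while True:
        elapsed = clock() - start
        if predicate():
            return elapsed
        if elapsed >= timeout_seconds:
            raise TimeoutError(f"condition not met within {timeout_seconds}s")
        sleep(interval)
```

In `fetch_remote_and_wait` the predicate is `check_remote_metadata(remote_agent_name)`; the 10 ms sleep keeps the wait loop from spinning a CPU core while NIXL's listen thread completes the exchange.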