Skip to content

Question on using NIXL+UCX+CUDA on machines with 2 CX7 and 8 GPUs #11259

@Morifolium

Description

@Morifolium

Describe the bug

I am using NIXL and UCX for KV cache transfer between two servers and have written some test code. Each server has 8 A100 GPUs and 2 CX7 NICs, where every 4 A100 GPUs and 1 CX7 NIC are under the same PCIe switch. The test results show that when transferring data from card 0 to card 7, or from card 7 to card 0, the bandwidth drops to only half of the CX7's maximum bandwidth. The UCX_LOG_LEVEL=data logs show that each transfer attempts to use mlx5_0 instead of the affinity-matched NIC.

Steps to Reproduce


#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import time

import torch

from nixl._api import nixl_agent, nixl_agent_config


def _parse_bool(value):
    """Parse a command-line boolean: accepts true/false, yes/no, 1/0 (case-insensitive)."""
    lowered = value.lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse command-line arguments for the NIXL bandwidth test.

    Returns:
        argparse.Namespace with attributes: ip, port, use_cuda, mode,
        log, warmup, iter, size.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ip", type=str, required=True)
    parser.add_argument("--port", type=int, default=5555)
    # GPU index to place the tensor on; -1 keeps it on the CPU.
    parser.add_argument("--use_cuda", type=int, default=-1)
    parser.add_argument(
        "--mode",
        type=str,
        default="initiator",
        choices=["initiator", "target"],  # reject typos early
        help="Local IP in target, peer IP (target's) in initiator",
    )
    # BUG FIX: the original used `type=bool`, which converts ANY non-empty
    # string (including "False") to True. Parse the text explicitly.
    parser.add_argument("--log", type=_parse_bool, default=False)
    parser.add_argument("--warmup", type=int, default=100, help="Number of warmup iterations")
    parser.add_argument("--iter", type=int, default=100, help="Number of timed iterations for bandwidth measurement")
    parser.add_argument("--size", type=int, default=1048576, help="Tensor size in bytes")
    return parser.parse_args()


def create_tensor(size_bytes, mode, use_cuda):
    """Allocate a 1-D float32 tensor occupying ``size_bytes`` bytes.

    The target side gets a tensor of ones (known payload); any other
    mode gets zeros (to be overwritten by the incoming transfer). A
    non-negative ``use_cuda`` moves the tensor to that GPU index.
    """
    element_count = size_bytes // 4  # float32 occupies 4 bytes per element
    filler = torch.ones if mode == "target" else torch.zeros
    tensor = filler(element_count, dtype=torch.float32)
    if use_cuda < 0:
        return tensor
    return tensor.to(f"cuda:{use_cuda}")


def setup_connection(agent, tensor, mode, args, target_ip, target_port):
    """Register memory and exchange transfer descriptors with the peer.

    Returns a ``(descriptors, reg_descs)`` pair. On the target side
    ``descriptors`` is the local descriptor list; on the initiator it is
    a ``(local_descs, target_descs)`` tuple. Returns ``(None, None)`` if
    registration or descriptor construction fails.
    """
    # One-time registration of the tensor's memory with the agent.
    reg_descs = agent.register_memory(tensor)
    if not reg_descs:
        print("Memory registration failed.")
        return None, None

    # Wrap the tensor in transfer descriptors.
    local_descs = agent.get_xfer_descs([tensor])
    if not local_descs:
        print("Failed to build local transfer descriptors.")
        agent.deregister_memory(reg_descs)
        return None, None

    if mode == "target":
        # Serialize our descriptors, busy-wait until the initiator's
        # metadata has arrived, then ship the descriptors over as a notif.
        serialized = agent.get_serialized_descs(local_descs)
        while not agent.check_remote_metadata("initiator"):
            pass
        agent.send_notif("initiator", serialized)
        return local_descs, reg_descs

    # Initiator: pull the target's metadata, publish our own, then wait
    # for the descriptor notification the target sends back.
    agent.fetch_remote_metadata("target", target_ip, target_port)
    agent.send_local_metadata(target_ip, target_port)

    notifs = agent.get_new_notifs()
    while not notifs:
        notifs = agent.get_new_notifs()
    target_descs = agent.deserialize_descs(notifs["target"][0])

    # Make sure the target's metadata is fully loaded before returning.
    while not agent.check_remote_metadata("target"):
        pass

    return (local_descs, target_descs), reg_descs


def run_bandwidth_test(agent, descriptors, mode, args, reg_descs):
    """Drive the WRITE bandwidth measurement over an established link.

    The target side only blocks until the initiator signals completion
    and returns ``None``. The initiator performs ``args.warmup`` untimed
    transfers followed by ``args.iter`` timed ones and returns a dict of
    timing/bandwidth figures, or ``None`` on any transfer error.
    """
    if mode == "target":
        local_descs = descriptors
        # Spin until the initiator reports that all writes are done.
        while True:
            notifs = agent.get_new_notifs()
            if "initiator" in notifs and b"Done_reading" in notifs["initiator"]:
                return None

    local_descs, target_descs = descriptors

    # --- Warmup phase: a fresh handle per iteration, nothing is timed. ---
    for _ in range(args.warmup):
        handle = agent.initialize_xfer("WRITE", local_descs, target_descs, "target")
        if agent.transfer(handle) == "ERR":
            print("Posting transfer failed during warmup.")
            return None
        state = agent.check_xfer_state(handle)
        while state not in ("DONE", "ERR"):
            state = agent.check_xfer_state(handle)
        if state == "ERR":
            print("Transfer got to Error state during warmup.")
            return None
        agent.release_xfer_handle(handle)

    # --- Timed phase: one handle reused across all iterations. ---
    handle = agent.initialize_xfer("WRITE", local_descs, target_descs, "target")
    started = time.perf_counter()
    for _ in range(args.iter):
        if agent.transfer(handle) == "ERR":
            print("Posting transfer failed during timed run.")
            return None
        state = agent.check_xfer_state(handle)
        while state not in ("DONE", "ERR"):
            state = agent.check_xfer_state(handle)
        if state == "ERR":
            print("Transfer got to Error state during timed run.")
            return None
    finished = time.perf_counter()
    agent.release_xfer_handle(handle)
    agent.send_notif("target", b"Done_reading")

    # Bandwidth over the timed iterations only (binary MB/GB units).
    elapsed = finished - started
    bytes_moved = args.size * args.iter
    return {
        "size_bytes": args.size,
        "iterations": args.iter,
        "total_time_sec": elapsed,
        "bandwidth_gbps": bytes_moved / (elapsed * 1024 ** 3),
        "bandwidth_mbps": bytes_moved / (elapsed * 1024 ** 2),
    }


def cleanup(agent, mode, reg_descs, target_ip, target_port):
    """Release transfer resources at the end of the run.

    Both sides deregister the tensor's memory; the initiator also drops
    its record of the remote agent and invalidates the local metadata it
    published during connection setup.
    """
    agent.deregister_memory(reg_descs)
    if mode != "initiator":
        return
    agent.remove_remote_agent("target")
    agent.invalidate_local_metadata(target_ip, target_port)


if __name__ == "__main__":
    args = parse_args()

    # The target listens on the user-supplied port; the initiator binds
    # an ephemeral port (0) since it only dials out.
    listen_port = args.port if args.mode == "target" else 0

    # Create single agent instance (global for this script), named after
    # its role so notifications can be addressed by role name.
    config = nixl_agent_config(True, True, listen_port)
    agent = nixl_agent(args.mode, config)

    # Allocate the transfer buffer (ones on target, zeros on initiator).
    tensor = create_tensor(args.size, args.mode, args.use_cuda)

    # One-time handshake: register memory and exchange descriptors,
    # outside the measurement loop.
    descriptors, reg_descs = setup_connection(agent, tensor, args.mode, args, args.ip, args.port)
    if descriptors is None:
        print("Connection setup failed.")
        # FIX: raise SystemExit rather than calling the `exit()` builtin,
        # which is installed by the site module and not guaranteed to
        # exist when Python runs without it (e.g. `python -S`).
        raise SystemExit(1)

    # Print the table header (initiator only; the target stays quiet).
    if args.mode == "initiator":
        print(f"{'Size (Bytes)':>12} {'Size (MB)':>10} {'Iterations':>10} {'Time (s)':>10} {'BW (MB/s)':>12} {'BW (GB/s)':>10}")
        print("-" * 70)
        print(f"Testing size: {args.size} bytes ({args.size / 1048576:.2f} MB)")

    # Run bandwidth test (single size, single call).
    result = run_bandwidth_test(agent, descriptors, args.mode, args, reg_descs)

    # Print result (initiator only; the target side returns None).
    if result and args.mode == "initiator":
        size_mb = result["size_bytes"] / 1048576
        print(f"{result['size_bytes']:12d} {size_mb:10.2f} {result['iterations']:10d} "
              f"{result['total_time_sec']:10.4f} {result['bandwidth_mbps']:12.2f} {result['bandwidth_gbps']:10.3f}")

    # Cleanup
    cleanup(agent, args.mode, reg_descs, args.ip, args.port)

    print("Bandwidth test complete.")

    # Example usage (FIX: the original example referenced --start_size and
    # --max_size, flags this script does not define; use --size instead):
    #
    # target first:
    #   python3 test.py --ip ${target_ip} --port ${target_port} --use_cuda 0 \
    #       --warmup 100 --iter 100 --size 536870912 --mode target
    #
    # initiator:
    #   python3 test.py --ip ${target_ip} --port ${target_port} --use_cuda 7 \
    #       --warmup 100 --iter 100 --size 536870912 --mode initiator

Setup and versions

2 Node
Each 8 * A100 + 2 * CX7
NIXL 0.9.0
UCX 1.18.0

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions