commit 031bc48953b9c2d721f5505a80f93bb25cf7b236 Author: autoscriptlabs Date: Fri Jan 9 14:09:33 2026 -0500 Initial release: NCCL Mesh Plugin for direct-connect RDMA topologies - Enables NCCL over multi-subnet mesh topologies - 8+ GB/s bandwidth over 100Gbps RDMA - Successfully tested with distributed LLM inference (Mistral-7B) - Custom subnet-aware NIC selection - Background handshake thread for deadlock-free connection setup diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..14fac91 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a2e1ac5 --- /dev/null +++ b/Makefile @@ -0,0 +1,56 @@ +# NCCL Mesh Plugin Makefile + +CC = gcc +CFLAGS = -Wall -Wextra -O2 -fPIC -g +CFLAGS += -I. 
-I./include +LDFLAGS = -shared -libverbs -lpthread + +# Target +TARGET = libnccl-net.so +TARGET_MESH = libnccl-net-mesh.so + +# Sources +SRCS = src/mesh_plugin.c +OBJS = $(SRCS:.c=.o) + +# Default target +all: $(TARGET) $(TARGET_MESH) + +$(TARGET): $(OBJS) + $(CC) $(OBJS) -o $@ $(LDFLAGS) + +$(TARGET_MESH): $(TARGET) + ln -sf $(TARGET) $(TARGET_MESH) + +%.o: %.c + $(CC) $(CFLAGS) -c $< -o $@ + +# Install to a standard location +PREFIX ?= /usr/local +install: all + install -d $(PREFIX)/lib + install -m 755 $(TARGET) $(PREFIX)/lib/ + ln -sf $(TARGET) $(PREFIX)/lib/$(TARGET_MESH) + +# Clean +clean: + rm -f $(OBJS) $(TARGET) $(TARGET_MESH) + +# Test build (requires libibverbs-dev) +test-deps: + @echo "Checking dependencies..." + @pkg-config --exists libibverbs || (echo "ERROR: libibverbs-dev not found" && exit 1) + @echo "All dependencies found." + +# Debug build +debug: CFLAGS += -DDEBUG -g3 -O0 +debug: clean all + +# Print configuration +info: + @echo "CC = $(CC)" + @echo "CFLAGS = $(CFLAGS)" + @echo "LDFLAGS = $(LDFLAGS)" + @echo "TARGET = $(TARGET)" + +.PHONY: all clean install test-deps debug info diff --git a/README.md b/README.md new file mode 100644 index 0000000..a769a6d --- /dev/null +++ b/README.md @@ -0,0 +1,244 @@ +# NCCL Mesh Plugin + +**Custom NCCL network plugin enabling distributed ML over direct-connect RDMA mesh topologies.** + +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) + +## 🎯 What This Does + +This plugin enables NCCL (NVIDIA Collective Communications Library) to work with **direct-connect mesh topologies** where each node pair is on a different subnet. Standard NCCL plugins assume either: +- A switched InfiniBand fabric (all nodes on same subnet) +- TCP/IP networking (slow, high latency) + +Neither works for direct-cabled RDMA meshes. This plugin does. 
+ +## πŸ”§ The Problem We Solved + +``` + β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ Spark-A β”‚ + β”‚ (titanic) β”‚ + β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + 192.168.101.x β”‚ 192.168.100.x + (100Gbps) β”‚ (100Gbps) + β”Œβ”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β” + β”‚ β”‚ + β”Œβ”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β” + β”‚ Spark-B β”‚ β”‚ Spark-C β”‚ + β”‚ (iceberg) β”‚ β”‚(carpathia)β”‚ + β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ + β”‚ β”‚ + β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”˜ + 192.168.102.x + (100Gbps) +``` + +**Three DGX Spark workstations** connected in a triangle mesh with direct 100Gbps RDMA cables. Each link is on a **different subnet** - a configuration NVIDIA never intended to support. + +## πŸš€ Results + +| Metric | Value | +|--------|-------| +| Effective Bandwidth | **8+ GB/s** | +| Line Rate Utilization | ~64% | +| Topology | 3-node triangle mesh | +| Link Speed | 100 Gbps per link | + +Successfully ran **distributed LLM inference** (Mistral-7B) across all 3 nodes using NCCL over this custom topology. + +## πŸ—οΈ Architecture + +### Key Innovations + +1. **Multi-Address Handle Exchange** + - Each node advertises ALL its subnet IPs in the NCCL handle + - Connector searches for reachable addresses by subnet matching + +2. **Subnet-Aware NIC Selection** + - `connect()` finds the local NIC on the same subnet as the peer + - Automatic routing without IP forwarding or bridges + +3. **Background Handshake Thread** + - Eliminates deadlock when both ranks call `connect()` simultaneously + - TCP-based QP info exchange runs asynchronously + +4. 
**Bidirectional QP Exchange** + - Each connection creates fresh Queue Pairs on both sides + - No QP reuse across multiple NCCL channels + +### RDMA Implementation + +- Raw InfiniBand Verbs API (libibverbs) +- Reliable Connected (RC) Queue Pairs +- RoCE v2 over Ethernet +- Host memory staging (GPUβ†’Hostβ†’RDMAβ†’Hostβ†’GPU) + +## πŸ“¦ Installation + +### Prerequisites + +```bash +# Ubuntu/Debian +sudo apt-get install libibverbs-dev librdmacm-dev + +# Verify RDMA devices +ibv_devices +``` + +### Build + +```bash +git clone https://github.com/yourusername/nccl-mesh-plugin.git +cd nccl-mesh-plugin +make +``` + +### Use + +```bash +export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH +export NCCL_NET_PLUGIN=mesh +export NCCL_DEBUG=INFO # or WARN for less output + +# Run your distributed job +python your_distributed_script.py +``` + +## πŸ§ͺ Testing + +### Basic All-Reduce Test + +```python +import torch +import torch.distributed as dist + +dist.init_process_group('nccl', rank=RANK, world_size=3, + init_method='tcp://MASTER_IP:29500') + +t = torch.ones(1000, device='cuda') +dist.all_reduce(t) +print(f'Result: {t[0]}') # Should print 3.0 + +dist.destroy_process_group() +``` + +### Bandwidth Benchmark + +```python +import torch +import torch.distributed as dist +import time + +dist.init_process_group('nccl', rank=RANK, world_size=3, + init_method='tcp://MASTER_IP:29500') + +t = torch.ones(1024*1024*64, device='cuda') # 256MB + +# Warmup +for _ in range(5): + dist.all_reduce(t) +torch.cuda.synchronize() + +# Benchmark +start = time.time() +for _ in range(20): + dist.all_reduce(t) +torch.cuda.synchronize() +elapsed = time.time() - start + +print(f'Bandwidth: {(256*20/1024)/elapsed:.2f} GB/s') +``` + +## πŸ”¬ How It Works + +### Connection Flow + +``` +Rank 0 (listen) Rank 1 (connect) + β”‚ β”‚ + β–Ό β”‚ + listen() β”‚ + β”œβ”€ Create QPs on ALL NICs β”‚ + β”œβ”€ Start handshake thread β”‚ + β”œβ”€ Return handle with all IPs β”‚ + β”‚ β”‚ + │◄──────── handle exchange ────────►│ + 
β”‚ β”‚ + β”‚ β–Ό + β”‚ connect() + β”‚ β”œβ”€ Find matching subnet + β”‚ β”œβ”€ Create QP on that NIC + β”‚ β”œβ”€ TCP handshake ──────────►│ + β”‚ β”‚ β”‚ + │◄────────────────────────────────────────── QP info ────── + β”‚ β”‚ β”‚ + β–Ό β–Ό β–Ό + accept() Connect QP [handshake thread] + β”œβ”€ Get QP from queue to peer's QP β”œβ”€ Accept TCP + └─ Return recv_comm β”‚ β”œβ”€ Create new QP + β”‚ β”œβ”€ Connect QPs + β”‚ └─ Queue for accept() + β”‚ + β”Œβ”€β”€β”€β”€β”΄β”€β”€β”€β”€β” + β”‚ RDMA OK β”‚ + β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### Subnet Matching + +```c +// For each peer address in handle +for (int i = 0; i < handle->num_addrs; i++) { + uint32_t peer_ip = handle->addrs[i].ip; + + // Find local NIC on same subnet + for (int j = 0; j < num_nics; j++) { + if ((peer_ip & nic[j].netmask) == nic[j].subnet) { + // Found matching NIC! + selected_nic = &nic[j]; + break; + } + } +} +``` + +## βš™οΈ Configuration + +| Environment Variable | Default | Description | +|---------------------|---------|-------------| +| `NCCL_NET_PLUGIN` | - | Set to `mesh` to use this plugin | +| `NCCL_DEBUG` | `WARN` | Set to `INFO` for detailed logs | +| `NCCL_MESH_GID_INDEX` | `3` | RoCE GID index to use | +| `NCCL_MESH_DEBUG` | `0` | Enable plugin debug output | + +## 🚧 Limitations + +- **Host memory staging**: GPU memory goes through host (no GPUDirect RDMA yet) +- **Single QP per connection**: No multi-rail aggregation +- **No relay routing**: Non-adjacent nodes can't communicate (fine for fully-connected mesh) +- **RoCE v2 only**: No InfiniBand support (Ethernet only) + +## πŸ—ΊοΈ Roadmap + +- [ ] GPUDirect RDMA support (bypass host memory) +- [ ] Multi-QP per connection for higher bandwidth +- [ ] Adaptive routing for partial meshes +- [ ] Performance tuning (inline data, signaling) + +## πŸ“š References + +- [NCCL Documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/) +- [RDMA Aware Networks Programming User 
Manual](https://www.mellanox.com/related-docs/prod_software/RDMA_Aware_Programming_user_manual.pdf) +- [InfiniBand Verbs API](https://github.com/linux-rdma/rdma-core) + +## πŸ“„ License + +MIT License - see [LICENSE](LICENSE) file. + +## πŸ™ Acknowledgments + +Built to connect three DGX Spark workstations that NVIDIA never intended to be clustered. Sometimes the best solutions come from ignoring "supported configurations." + +--- + +*"The future of distributed AI computing is here."* - Mistral-7B, running on this very plugin diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md new file mode 100644 index 0000000..9400c51 --- /dev/null +++ b/docs/ARCHITECTURE.md @@ -0,0 +1,337 @@ +# NCCL Mesh Plugin Architecture + +This document provides a deep dive into the architecture and implementation of the NCCL Mesh Plugin. + +## Overview + +The NCCL Mesh Plugin is a custom network transport that enables NCCL to work with direct-connect RDMA mesh topologies where each node pair is on a different subnet. This is a configuration that standard NCCL plugins cannot handle. + +## The Problem + +### Standard NCCL Networking + +NCCL's built-in network plugins assume one of two scenarios: + +1. **InfiniBand Fabric**: All nodes connected through IB switches, sharing a single subnet +2. **TCP/IP Sockets**: Standard IP networking with routing + +### Our Topology + +``` + Node A (192.168.100.2, 192.168.101.2) + / \ + 192.168.100.x 192.168.101.x + / \ + Node C Node B +(192.168.100.3, (192.168.101.3, + 192.168.102.3) 192.168.102.2) + \ / + \ 192.168.102.x / + \ / + \--------------/ +``` + +Each link is on a **different subnet**: +- A↔B: 192.168.101.0/24 +- A↔C: 192.168.100.0/24 +- B↔C: 192.168.102.0/24 + +This means: +- No single IP can reach all peers +- Standard IB plugin fails (expects single subnet) +- TCP socket plugin would need IP routing (adds latency) + +## Solution Architecture + +### Key Insight + +Each node has **multiple NICs**, each on a different subnet. 
When connecting to a peer, we must: +1. Determine which subnet the peer is on +2. Use the local NIC on that same subnet +3. Establish RDMA connection over that specific NIC pair + +### Handle Structure + +The NCCL handle is expanded to advertise **all** local addresses: + +```c +struct mesh_handle { + uint32_t magic; // Validation + uint8_t num_addrs; // Number of addresses + uint16_t handshake_port; // TCP port for QP exchange + + struct mesh_addr_entry { + uint32_t ip; // IP address (network order) + uint32_t mask; // Subnet mask + uint32_t qp_num; // Queue Pair number + uint8_t nic_idx; // Index into local NIC array + } addrs[MESH_MAX_ADDRS]; +}; +``` + +### Connection Flow + +#### Phase 1: Listen + +```c +ncclResult_t mesh_listen(int dev, void *handle, void **listenComm) { + // 1. Create QPs on ALL local NICs + for (int i = 0; i < num_nics; i++) { + create_qp_on_nic(&nics[i]); + } + + // 2. Start background handshake thread + pthread_create(&thread, handshake_thread_func, lcomm); + + // 3. Fill handle with ALL addresses + for (int i = 0; i < num_nics; i++) { + handle->addrs[i].ip = nics[i].ip_addr; + handle->addrs[i].mask = nics[i].netmask; + handle->addrs[i].qp_num = qps[i]->qp_num; + } +} +``` + +#### Phase 2: Connect + +```c +ncclResult_t mesh_connect(int dev, void *handle, void **sendComm) { + // 1. Search peer's addresses for reachable one + for (int i = 0; i < handle->num_addrs; i++) { + uint32_t peer_subnet = handle->addrs[i].ip & handle->addrs[i].mask; + + // Find local NIC on same subnet + for (int j = 0; j < num_local_nics; j++) { + if (local_nics[j].subnet == peer_subnet) { + selected_nic = &local_nics[j]; + selected_peer_addr = &handle->addrs[i]; + break; + } + } + } + + // 2. Create QP on selected NIC + create_qp_on_nic(selected_nic); + + // 3. Exchange QP info via TCP handshake + send_handshake(peer_ip, peer_port, &local_qp_info, &remote_qp_info); + + // 4. 
Connect QP to peer's QP + connect_qp(local_qp, remote_qp_info); +} +``` + +#### Phase 3: Accept + +```c +ncclResult_t mesh_accept(void *listenComm, void **recvComm) { + // Get pre-connected QP from handshake thread's queue + pthread_mutex_lock(&queue_mutex); + while (queue_empty) { + pthread_cond_wait(&queue_cond, &queue_mutex); + } + entry = dequeue(); + pthread_mutex_unlock(&queue_mutex); + + // Return the ready connection + rcomm->qp = entry->local_qp; + rcomm->nic = entry->nic; +} +``` + +### Background Handshake Thread + +The handshake thread solves a critical deadlock problem: + +**Without thread:** +``` +Rank 0: connect() β†’ TCP connect to Rank 1 β†’ blocks waiting for accept() +Rank 1: connect() β†’ TCP connect to Rank 0 β†’ blocks waiting for accept() +// DEADLOCK: Neither can call accept() because both stuck in connect() +``` + +**With thread:** +``` +Rank 0: listen() starts thread β†’ thread waits for TCP connections +Rank 1: listen() starts thread β†’ thread waits for TCP connections +Rank 0: connect() β†’ TCP connects to Rank 1's thread β†’ gets response β†’ returns +Rank 1: connect() β†’ TCP connects to Rank 0's thread β†’ gets response β†’ returns +Rank 0: accept() β†’ gets QP from queue (filled by thread) β†’ returns +Rank 1: accept() β†’ gets QP from queue (filled by thread) β†’ returns +// SUCCESS: Thread handles incoming connections asynchronously +``` + +### RDMA Queue Pair Setup + +Each connection requires proper QP state transitions: + +``` +RESET β†’ INIT β†’ RTR β†’ RTS +``` + +```c +int mesh_connect_qp(struct ibv_qp *qp, struct mesh_nic *nic, + struct mesh_handle *remote) { + // RESET β†’ INIT + qp_attr.qp_state = IBV_QPS_INIT; + qp_attr.pkey_index = 0; + qp_attr.port_num = nic->port_num; + qp_attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ | + IBV_ACCESS_LOCAL_WRITE; + ibv_modify_qp(qp, &qp_attr, ...); + + // INIT β†’ RTR (Ready to Receive) + qp_attr.qp_state = IBV_QPS_RTR; + qp_attr.path_mtu = IBV_MTU_4096; + 
qp_attr.dest_qp_num = remote->qp_num; + qp_attr.rq_psn = remote->psn; + qp_attr.ah_attr.dlid = remote->lid; // 0 for RoCE + qp_attr.ah_attr.grh.dgid = remote->gid; // Peer's GID + ibv_modify_qp(qp, &qp_attr, ...); + + // RTR β†’ RTS (Ready to Send) + qp_attr.qp_state = IBV_QPS_RTS; + qp_attr.sq_psn = local_psn; + qp_attr.timeout = 14; + qp_attr.retry_cnt = 7; + qp_attr.rnr_retry = 7; + ibv_modify_qp(qp, &qp_attr, ...); +} +``` + +### Data Transfer + +#### Send Path + +```c +ncclResult_t mesh_isend(void *sendComm, void *data, int size, + void *mhandle, void **request) { + struct ibv_send_wr wr = { + .wr_id = (uint64_t)req, + .sg_list = &sge, + .num_sge = 1, + .opcode = IBV_WR_SEND, + .send_flags = IBV_SEND_SIGNALED, + }; + + sge.addr = (uint64_t)data; + sge.length = size; + sge.lkey = mr->lkey; + + ibv_post_send(comm->qp, &wr, &bad_wr); +} +``` + +#### Receive Path + +```c +ncclResult_t mesh_irecv(void *recvComm, int n, void **data, + int *sizes, void **mhandles, void **request) { + struct ibv_recv_wr wr = { + .wr_id = (uint64_t)req, + .sg_list = &sge, + .num_sge = 1, + }; + + sge.addr = (uint64_t)data[0]; + sge.length = sizes[0]; + sge.lkey = mr->lkey; + + ibv_post_recv(comm->qp, &wr, &bad_wr); +} +``` + +#### Completion Polling + +```c +ncclResult_t mesh_test(void *request, int *done, int *sizes) { + struct ibv_wc wc; + + int ret = ibv_poll_cq(req->cq, 1, &wc); + if (ret > 0) { + if (wc.status == IBV_WC_SUCCESS) { + *done = 1; + if (sizes) *sizes = wc.byte_len; + } else { + // Handle error + } + } else { + *done = 0; // Not complete yet + } +} +``` + +## Memory Registration + +RDMA requires memory to be registered with the NIC: + +```c +ncclResult_t mesh_regMr(void *comm, void *data, size_t size, + int type, void **mhandle) { + int access = IBV_ACCESS_LOCAL_WRITE | + IBV_ACCESS_REMOTE_WRITE | + IBV_ACCESS_REMOTE_READ; + + mrh->mr = ibv_reg_mr(nic->pd, data, size, access); + *mhandle = mrh; +} +``` + +**Note**: Current implementation uses host memory staging. 
GPU memory is copied to host, sent via RDMA, then copied back to GPU on the receiver. GPUDirect RDMA would eliminate these copies. + +## Performance Considerations + +### Current Bottlenecks + +1. **Host Memory Staging**: GPU↔Host copies add latency +2. **Single QP**: One Queue Pair per connection limits parallelism +3. **Completion Signaling**: Every operation signals completion + +### Achieved Performance + +- **8+ GB/s** effective bandwidth +- **~64%** of 100 Gbps line rate +- Sufficient for distributed ML workloads + +### Future Optimizations + +1. **GPUDirect RDMA**: Register GPU memory directly +2. **Multi-QP**: Multiple QPs per connection +3. **Selective Signaling**: Signal every N operations +4. **Inline Data**: Small messages in WQE + +## File Structure + +``` +nccl-mesh-plugin/ +β”œβ”€β”€ src/ +β”‚ └── mesh_plugin.c # Main implementation (~1400 lines) +β”œβ”€β”€ include/ +β”‚ └── mesh_plugin.h # Data structures and declarations +β”œβ”€β”€ nccl/ +β”‚ β”œβ”€β”€ net.h # NCCL net plugin interface +β”‚ β”œβ”€β”€ net_v8.h # v8 properties structure +β”‚ └── err.h # NCCL error codes +└── Makefile +``` + +## Debugging + +Enable debug output: + +```bash +export NCCL_DEBUG=INFO +export NCCL_MESH_DEBUG=1 +``` + +Common issues: + +1. **"No local NIC found"**: Subnet mismatch, check IP configuration +2. **"Handshake timeout"**: Firewall blocking TCP, check ports +3. **"QP transition failed"**: GID index wrong, try different `NCCL_MESH_GID_INDEX` +4. **"WC error status=12"**: Transport retry exceeded, check RDMA connectivity + +## Conclusion + +The NCCL Mesh Plugin demonstrates that with careful engineering, NCCL can be extended to support unconventional network topologies. The key innovationsβ€”multi-address handles, subnet-aware NIC selection, and asynchronous handshakingβ€”provide a template for other custom NCCL transports. 
diff --git a/docs/SETUP.md b/docs/SETUP.md new file mode 100644 index 0000000..f37f982 --- /dev/null +++ b/docs/SETUP.md @@ -0,0 +1,249 @@ +# Hardware Setup Guide + +This guide covers setting up a direct-connect RDMA mesh topology with multiple nodes. + +## Overview + +Our reference setup uses three NVIDIA DGX Spark workstations connected in a triangle mesh topology. Each pair of nodes has a dedicated 100 Gbps RDMA link on its own subnet. + +## Hardware Requirements + +- 3+ nodes with RDMA-capable NICs (ConnectX-6/7 recommended) +- Direct-attach cables (QSFP56 for 100GbE) +- Each node needs N-1 NICs for N nodes in a fully-connected mesh + +## Network Topology + +### Triangle Mesh (3 Nodes) + +``` + Node A + / \ + NIC1 NIC2 + | | +192.168.101.x 192.168.100.x + | | + NIC1 NIC1 + | | + Node B ---- Node C + NIC2 + 192.168.102.x +``` + +### IP Address Assignment + +| Link | Subnet | Node A | Node B | Node C | +|------|--------|--------|--------|--------| +| A↔B | 192.168.101.0/24 | .2 | .3 | - | +| A↔C | 192.168.100.0/24 | .2 | - | .3 | +| B↔C | 192.168.102.0/24 | - | .2 | .3 | + +## Network Configuration + +### 1. Identify NICs + +```bash +# List RDMA devices +ibv_devices + +# List network interfaces with RDMA +ls -la /sys/class/infiniband/*/device/net/ +``` + +### 2. 
Configure IP Addresses + +On **Node A** (example): + +```bash +# Link to Node B +sudo ip addr add 192.168.101.2/24 dev enp1s0f0np0 +sudo ip link set enp1s0f0np0 up + +# Link to Node C +sudo ip addr add 192.168.100.2/24 dev enp1s0f1np1 +sudo ip link set enp1s0f1np1 up +``` + +On **Node B**: + +```bash +# Link to Node A +sudo ip addr add 192.168.101.3/24 dev enp1s0f0np0 +sudo ip link set enp1s0f0np0 up + +# Link to Node C +sudo ip addr add 192.168.102.2/24 dev enp1s0f1np1 +sudo ip link set enp1s0f1np1 up +``` + +On **Node C**: + +```bash +# Link to Node A +sudo ip addr add 192.168.100.3/24 dev enp1s0f0np0 +sudo ip link set enp1s0f0np0 up + +# Link to Node B +sudo ip addr add 192.168.102.3/24 dev enp1s0f1np1 +sudo ip link set enp1s0f1np1 up +``` + +### 3. Make Configuration Persistent + +Create netplan config (Ubuntu): + +```yaml +# /etc/netplan/99-rdma-mesh.yaml +network: + version: 2 + ethernets: + enp1s0f0np0: + addresses: + - 192.168.101.2/24 # Adjust per node + enp1s0f1np1: + addresses: + - 192.168.100.2/24 # Adjust per node +``` + +Apply: +```bash +sudo netplan apply +``` + +## Verify Connectivity + +### 1. Ping Test + +From Node A: +```bash +ping 192.168.101.3 # Node B +ping 192.168.100.3 # Node C +``` + +### 2. RDMA Test + +```bash +# On Node B (server) +ib_send_bw -d rocep1s0f0 -x 3 + +# On Node A (client) +ib_send_bw -d rocep1s0f0 -x 3 192.168.101.3 +``` + +Expected output: ~12 GB/s for 100GbE + +### 3. 
Verify GID Index + +```bash +# Show GID table +show_gids + +# Find RoCE v2 GID (usually index 3) +ibv_devinfo -v | grep -A5 GID +``` + +## RoCE Configuration + +### Enable RoCE v2 + +```bash +# Check current mode +cat /sys/class/infiniband/rocep*/ports/1/gid_attrs/types/* + +# Enable RoCE v2 (if needed) +echo "RoCE v2" | sudo tee /sys/class/infiniband/rocep1s0f0/ports/1/gid_attrs/types/0 +``` + +### Configure ECN (Optional but Recommended) + +```bash +# Enable ECN for RoCE +sudo sysctl -w net.ipv4.tcp_ecn=1 + +# Configure PFC (Priority Flow Control) on switch if applicable +``` + +## Firewall Configuration + +Open ports for NCCL communication: + +```bash +# TCP ports for handshake (dynamic, 40000-50000 range) +sudo ufw allow 40000:50000/tcp + +# Or disable firewall for mesh interfaces +sudo ufw allow in on enp1s0f0np0 +sudo ufw allow in on enp1s0f1np1 +``` + +## Troubleshooting + +### No RDMA Devices Found + +```bash +# Load kernel modules +sudo modprobe ib_core +sudo modprobe mlx5_core +sudo modprobe mlx5_ib + +# Check dmesg +dmesg | grep -i mlx +``` + +### Link Not Coming Up + +```bash +# Check physical connection +ethtool enp1s0f0np0 + +# Check for errors +ip -s link show enp1s0f0np0 +``` + +### RDMA Connection Fails + +```bash +# Verify GID is populated +cat /sys/class/infiniband/rocep1s0f0/ports/1/gids/3 + +# Check RDMA CM +rdma link show +``` + +### Wrong GID Index + +Try different GID indices: + +```bash +export NCCL_MESH_GID_INDEX=0 # or 1, 2, 3... +``` + +## Scaling Beyond 3 Nodes + +For N nodes in a fully-connected mesh: +- Each node needs N-1 NICs +- Total links: N*(N-1)/2 +- Each link on unique subnet + +For 4 nodes: +``` + A + /|\ + B-+-C + \|/ + D +``` +- 6 links, 6 subnets +- Each node needs 3 NICs + +For larger clusters, consider a **partial mesh** or **fat-tree** topology with relay routing (not yet implemented in this plugin). 
+ +## Reference: DGX Spark Mesh + +Our tested configuration: + +| Hostname | Management IP | Mesh IPs | +|----------|--------------|----------| +| titanic (A) | 10.0.0.170 | 192.168.100.2, 192.168.101.2 | +| iceberg (B) | 10.0.0.171 | 192.168.101.3, 192.168.102.2 | +| carpathia (C) | 10.0.0.172 | 192.168.100.3, 192.168.102.3 | diff --git a/examples/benchmark_bandwidth.py b/examples/benchmark_bandwidth.py new file mode 100644 index 0000000..53b34a7 --- /dev/null +++ b/examples/benchmark_bandwidth.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +""" +Bandwidth benchmark for NCCL Mesh Plugin + +Usage: + # On each node (adjust --rank): + python benchmark_bandwidth.py --rank 0 --world-size 3 --master-ip 10.0.0.170 +""" + +import argparse +import time +import torch +import torch.distributed as dist + + +def benchmark_allreduce(size_mb: int, iterations: int, warmup: int = 5): + """Benchmark all-reduce bandwidth""" + + # Create tensor + num_elements = (size_mb * 1024 * 1024) // 4 # float32 = 4 bytes + tensor = torch.ones(num_elements, device='cuda', dtype=torch.float32) + + # Warmup + for _ in range(warmup): + dist.all_reduce(tensor) + torch.cuda.synchronize() + + # Benchmark + start = time.perf_counter() + for _ in range(iterations): + dist.all_reduce(tensor) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + # Calculate bandwidth + # NOTE: this reports aggregate algorithm bandwidth (payload * iterations / time). + # A ring all-reduce actually moves 2*(N-1)/N * size per rank; multiply by that + # factor if you want "bus bandwidth" comparable to nccl-tests output. + total_data_gb = (size_mb * iterations) / 1024 + bandwidth_gbs = total_data_gb / elapsed + + return bandwidth_gbs, elapsed + + +def main(): + parser = argparse.ArgumentParser(description='Benchmark NCCL bandwidth') + parser.add_argument('--rank', type=int, required=True) + parser.add_argument('--world-size', type=int, default=3) + parser.add_argument('--master-ip', type=str, default='10.0.0.170') + parser.add_argument('--master-port', type=int, default=29500) + parser.add_argument('--iterations', type=int, default=20) + args = parser.parse_args() + + # Initialize + 
init_method = f'tcp://{args.master_ip}:{args.master_port}' + dist.init_process_group('nccl', rank=args.rank, world_size=args.world_size, + init_method=init_method) + + if args.rank == 0: + print(f'\n{"="*60}') + print(f'NCCL Mesh Plugin Bandwidth Benchmark') + print(f'World size: {args.world_size}') + print(f'Iterations per size: {args.iterations}') + print(f'{"="*60}\n') + print(f'{"Size":<12} {"Bandwidth":<15} {"Time":<12}') + print(f'{"-"*12} {"-"*15} {"-"*12}') + + # Test different sizes + sizes_mb = [1, 4, 16, 64, 128, 256, 512] + + for size_mb in sizes_mb: + bandwidth, elapsed = benchmark_allreduce(size_mb, args.iterations) + + if args.rank == 0: + print(f'{size_mb:>6} MB {bandwidth:>8.2f} GB/s {elapsed:>6.3f} s') + + # Sync between sizes + dist.barrier() + + if args.rank == 0: + print(f'\n{"="*60}') + print('Benchmark complete!') + print(f'{"="*60}\n') + + dist.destroy_process_group() + + +if __name__ == '__main__': + main() diff --git a/examples/distributed_llm.py b/examples/distributed_llm.py new file mode 100644 index 0000000..9ced05b --- /dev/null +++ b/examples/distributed_llm.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +""" +Distributed LLM Inference with NCCL Mesh Plugin + +This example demonstrates loading and running inference on a large language +model distributed across multiple GPUs using the NCCL Mesh Plugin. 
+ +Usage: + # On each node (adjust --rank): + python distributed_llm.py --rank 0 --world-size 3 --master-ip 10.0.0.170 + +Environment setup (run on each node): + cd ~/nccl-mesh-plugin + export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH + export NCCL_NET_PLUGIN=mesh + export NCCL_DEBUG=WARN +""" + +import argparse +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from accelerate import Accelerator + + +def main(): + parser = argparse.ArgumentParser(description='Distributed LLM Inference') + parser.add_argument('--rank', type=int, required=True) + parser.add_argument('--world-size', type=int, default=3) + parser.add_argument('--master-ip', type=str, default='10.0.0.170') + parser.add_argument('--master-port', type=int, default=29500) + parser.add_argument('--model', type=str, default='mistralai/Mistral-7B-Instruct-v0.2', + help='Model to load (default: Mistral-7B)') + parser.add_argument('--prompt', type=str, + default='The future of distributed AI computing is', + help='Prompt for generation') + parser.add_argument('--max-tokens', type=int, default=100, + help='Maximum tokens to generate') + args = parser.parse_args() + + # Initialize accelerator + accelerator = Accelerator() + + print(f'Rank {accelerator.process_index}: Loading tokenizer...') + tokenizer = AutoTokenizer.from_pretrained(args.model) + + print(f'Rank {accelerator.process_index}: Loading model...') + model = AutoModelForCausalLM.from_pretrained( + args.model, + torch_dtype=torch.bfloat16, + device_map='auto', + ) + + print(f'Rank {accelerator.process_index}: Model loaded!') + + # Only rank 0 generates + if accelerator.is_main_process: + print(f'\nGenerating text...') + print(f'Prompt: "{args.prompt}"\n') + + inputs = tokenizer(args.prompt, return_tensors='pt').to('cuda') + + outputs = model.generate( + **inputs, + max_new_tokens=args.max_tokens, + do_sample=True, + temperature=0.7, + top_p=0.9, + ) + + result = tokenizer.decode(outputs[0], skip_special_tokens=True) + + print('=' 
* 60) + print('Generated Text:') + print('=' * 60) + print(result) + print('=' * 60) + + # Wait for all ranks + accelerator.wait_for_everyone() + print(f'Rank {accelerator.process_index}: Done!') + + +if __name__ == '__main__': + main() diff --git a/examples/test_allreduce.py b/examples/test_allreduce.py new file mode 100644 index 0000000..0b29b02 --- /dev/null +++ b/examples/test_allreduce.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +Basic all-reduce test for NCCL Mesh Plugin + +Usage: + # On rank 0: + python test_allreduce.py --rank 0 --world-size 3 --master-ip 10.0.0.170 + + # On rank 1: + python test_allreduce.py --rank 1 --world-size 3 --master-ip 10.0.0.170 + + # On rank 2: + python test_allreduce.py --rank 2 --world-size 3 --master-ip 10.0.0.170 +""" + +import argparse +import torch +import torch.distributed as dist + + +def main(): + parser = argparse.ArgumentParser(description='Test NCCL all-reduce') + parser.add_argument('--rank', type=int, required=True, help='Rank of this process') + parser.add_argument('--world-size', type=int, default=3, help='Total number of processes') + parser.add_argument('--master-ip', type=str, default='10.0.0.170', help='Master node IP') + parser.add_argument('--master-port', type=int, default=29500, help='Master node port') + args = parser.parse_args() + + # Initialize process group + init_method = f'tcp://{args.master_ip}:{args.master_port}' + print(f'Rank {args.rank}: Initializing with {init_method}') + + dist.init_process_group( + backend='nccl', + rank=args.rank, + world_size=args.world_size, + init_method=init_method + ) + + print(f'Rank {args.rank}: Process group initialized') + + # Create tensor on GPU + tensor = torch.ones(1000, device='cuda') + print(f'Rank {args.rank}: Created tensor with sum = {tensor.sum().item()}') + + # All-reduce (sum) + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + + result = tensor[0].item() + expected = float(args.world_size) + + print(f'Rank {args.rank}: After all-reduce, tensor[0] = 
{result}') + + if abs(result - expected) < 0.001: + print(f'Rank {args.rank}: βœ“ SUCCESS! Result matches expected value {expected}') + else: + print(f'Rank {args.rank}: βœ— FAILED! Expected {expected}, got {result}') + + # Cleanup + dist.destroy_process_group() + print(f'Rank {args.rank}: Done') + + +if __name__ == '__main__': + main() diff --git a/include/mesh_plugin.h b/include/mesh_plugin.h new file mode 100644 index 0000000..ca57e2e --- /dev/null +++ b/include/mesh_plugin.h @@ -0,0 +1,257 @@ +/* + * NCCL Mesh Plugin - Subnet-aware RDMA transport + * + * Enables NCCL to work with direct-connect mesh topologies where + * each node pair is on a different subnet. + */ + +#ifndef NCCL_MESH_PLUGIN_H +#define NCCL_MESH_PLUGIN_H + +#include +#include +#include + +#define MESH_MAX_NICS 8 +#define MESH_MAX_QPS 256 +#define MESH_MAX_MRS 1024 +#define MESH_HANDLE_MAGIC 0x4D455348 // "MESH" + +// Forward declarations +struct mesh_plugin_state; +struct mesh_nic; +struct mesh_comm; + +/* + * Represents one RDMA-capable NIC with its subnet information + */ +struct mesh_nic { + // RDMA resources + struct ibv_context *context; + struct ibv_pd *pd; + int port_num; + int gid_index; + + // Network addressing + uint32_t ip_addr; // Host byte order + uint32_t netmask; // Host byte order + uint32_t subnet; // ip_addr & netmask + + // Device identification + char dev_name[64]; // RDMA device name (e.g., "rocep1s0f1") + char if_name[64]; // Network interface name (e.g., "enp1s0f1np1") + char pci_path[256]; // PCI bus path + + // Capabilities + int max_qp; + int max_cq; + int max_mr; + int max_sge; + uint64_t max_mr_size; + int gdr_supported; // GPUDirect RDMA support + + // Statistics + uint64_t bytes_sent; + uint64_t bytes_recv; + uint64_t connections; +}; + +/* + * Address entry for multi-homed hosts + */ +#define MESH_MAX_ADDRS 6 + +struct mesh_addr_entry { + uint32_t ip; // IP address (network byte order) + uint32_t mask; // Subnet mask (network byte order) + uint16_t qp_num; // QP 
number for this NIC
    uint8_t  nic_idx;             // Index into our NIC array
    uint8_t  gid_index;           // GID index for this NIC
};

/*
 * Connection handle - exchanged between peers during setup
 * Must fit within NCCL_NET_HANDLE_MAXSIZE (128 bytes)
 *
 * NOTE(review): this struct crosses the wire between ranks, so both peers
 * must agree on layout; it is not declared packed, so identical
 * compiler/arch on both ends is assumed — TODO confirm.
 */
struct mesh_handle {
    uint32_t magic;               // MESH_HANDLE_MAGIC
    uint8_t  num_addrs;           // Number of valid addresses
    uint8_t  selected_idx;        // Which address was selected (set by connect)
    uint16_t lid;                 // IB LID (0 for RoCE)
    uint16_t qp_num;              // QP number (for compat with mesh_connect_qp)
    uint16_t handshake_port;      // TCP port for QP handshake
    uint8_t  port_num;            // Port number (usually 1)
    uint8_t  mtu;                 // MTU setting
    uint32_t psn;                 // Packet sequence number
    uint32_t handshake_ip;        // IP address for handshake (network byte order)
    union ibv_gid gid;            // GID (16 bytes)
    struct mesh_addr_entry addrs[MESH_MAX_ADDRS]; // 12 bytes each
    // Total: 4+1+1+2+2+2+1+1+4+4+16 + 6*12 = 38 + 72 = 110 bytes (fits in 128)
};

/*
 * Listen state - waiting for incoming connections
 * Creates QPs on ALL NICs so any peer can connect
 */
#define HANDSHAKE_QUEUE_SIZE 16

/*
 * QP info exchanged during handshake
 * Sent verbatim over the handshake TCP socket; integer fields are stored in
 * network byte order by the sender (see mesh_plugin.c).
 */
struct mesh_qp_info {
    uint32_t qp_num;              // Network byte order
    uint32_t psn;                 // Network byte order
    uint8_t  gid[16];             // Raw GID
    uint32_t ip;                  // Network byte order
    uint8_t  gid_index;
    uint8_t  nic_idx;             // Which NIC on the listener
    uint8_t  reserved[2];
};

/* One completed inbound handshake, produced by the background thread
 * and consumed by accept(). */
struct handshake_entry {
    struct mesh_qp_info remote_info;  // Peer's QP parameters
    struct ibv_qp *local_qp;          // QP created for this connection
    struct ibv_cq *local_cq;          // CQ paired with local_qp
    struct mesh_nic *nic;             // NIC the QP lives on
    int valid;
};

struct mesh_listen_comm {
    int num_qps;                  // One pre-created QP per NIC
    struct {
        struct mesh_nic *nic;
        struct ibv_qp *qp;
        struct ibv_cq *cq;
    } qps[MESH_MAX_NICS];
    uint32_t psn;
    int ready;

    // Handshake socket for QP info exchange
    int handshake_sock;
    uint16_t handshake_port;
    uint32_t handshake_ip;

    // Background handshake thread
    pthread_t handshake_thread;
    int thread_running;           // Set when pthread_create succeeded
    int thread_stop;              // Polled by the thread as a stop flag

    // Queue of received handshakes for accept() to consume
    // (ring buffer; head==tail means empty, guarded by queue_mutex)
    struct handshake_entry handshake_queue[HANDSHAKE_QUEUE_SIZE];
    int queue_head;
    int queue_tail;
    pthread_mutex_t queue_mutex;
    pthread_cond_t queue_cond;
};

/*
 * Send/Receive communication state
 */
struct mesh_send_comm {
    struct mesh_nic *nic;
    struct ibv_qp *qp;
    struct ibv_cq *cq;
    uint32_t remote_qp_num;
    union ibv_gid remote_gid;
    int connected;

    // Request tracking
    struct mesh_request *requests[MESH_MAX_QPS];
    int num_requests;
};

struct mesh_recv_comm {
    struct mesh_nic *nic;
    struct ibv_qp *qp;
    struct ibv_cq *cq;
    int connected;

    // Request tracking
    struct mesh_request *requests[MESH_MAX_QPS];
    int num_requests;
};

/*
 * Memory registration handle
 */
struct mesh_mr_handle {
    struct ibv_mr *mr;            // Underlying verbs memory region
    struct mesh_nic *nic;         // NIC whose PD the MR belongs to
    void *addr;
    size_t size;
};

/*
 * Async request state
 */
struct mesh_request {
    int used;
    int done;
    size_t size;
    struct ibv_cq *cq;            // CQ to poll for completion
    struct ibv_wc wc;
};

/*
 * Global plugin state
 */
struct mesh_plugin_state {
    struct mesh_nic nics[MESH_MAX_NICS];
    int num_nics;
    int initialized;

    // Configuration
    int gid_index;                // From NCCL_MESH_GID_INDEX
    int debug;                    // From NCCL_MESH_DEBUG

    // Logging (provided by NCCL)
    void (*log_fn)(int level, unsigned long flags, const char *file,
                   int line, const char *fmt, ...);
};

// Global state (singleton)
extern struct mesh_plugin_state g_mesh_state;

/*
 * Internal functions
 */

// Initialization
int mesh_init_nics(void);
int mesh_discover_nic_ips(void);
int mesh_setup_nic(struct mesh_nic *nic, struct ibv_device *device);

// Routing
struct mesh_nic* mesh_find_nic_for_ip(uint32_t peer_ip);
struct mesh_nic* mesh_find_nic_by_name(const char *name);
int mesh_get_nic_index(struct mesh_nic *nic);

// RDMA operations
int mesh_create_qp(struct mesh_nic *nic, struct ibv_qp **qp, struct ibv_cq **cq);
int mesh_connect_qp(struct ibv_qp *qp, struct mesh_nic *nic, struct mesh_handle *remote);
int mesh_post_send(struct mesh_send_comm *comm, void *data, size_t size,
                   struct mesh_mr_handle *mr, struct mesh_request *req);
int mesh_post_recv(struct mesh_recv_comm *comm, void *data, size_t size,
                   struct mesh_mr_handle *mr, struct mesh_request *req);
int mesh_poll_cq(struct ibv_cq *cq, struct mesh_request *req);

// Utilities
uint32_t mesh_ip_to_uint(const char *ip_str);
void mesh_uint_to_ip(uint32_t ip, char *buf, size_t len);
int mesh_get_interface_ip(const char *if_name, uint32_t *ip, uint32_t *mask);
const char* mesh_find_netdev_for_rdma(const char *rdma_dev);

// Logging macros (no-ops until NCCL installs log_fn via init())
#define MESH_LOG(level, fmt, ...) \
    do { \
        if (g_mesh_state.log_fn) { \
            g_mesh_state.log_fn(level, 0, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \
        } \
    } while(0)

#define MESH_INFO(fmt, ...) MESH_LOG(NCCL_LOG_INFO, "MESH " fmt, ##__VA_ARGS__)
#define MESH_WARN(fmt, ...) MESH_LOG(NCCL_LOG_WARN, "MESH " fmt, ##__VA_ARGS__)
#define MESH_DEBUG(fmt, ...)
\ + do { if (g_mesh_state.debug) MESH_LOG(NCCL_LOG_TRACE, "MESH " fmt, ##__VA_ARGS__); } while(0) + +#endif // NCCL_MESH_PLUGIN_H diff --git a/nccl/err.h b/nccl/err.h new file mode 100644 index 0000000..ff29a06 --- /dev/null +++ b/nccl/err.h @@ -0,0 +1,47 @@ +/* + * NCCL error codes - extracted from NCCL headers + */ + +#ifndef NCCL_ERR_H +#define NCCL_ERR_H + +typedef enum { + ncclSuccess = 0, + ncclUnhandledCudaError = 1, + ncclSystemError = 2, + ncclInternalError = 3, + ncclInvalidArgument = 4, + ncclInvalidUsage = 5, + ncclRemoteError = 6, + ncclInProgress = 7, + ncclNumResults = 8 +} ncclResult_t; + +// Logging levels +#define NCCL_LOG_NONE 0 +#define NCCL_LOG_VERSION 1 +#define NCCL_LOG_WARN 2 +#define NCCL_LOG_INFO 3 +#define NCCL_LOG_ABORT 4 +#define NCCL_LOG_TRACE 5 + +// Debug logger function type +typedef void (*ncclDebugLogger_t)(int level, unsigned long flags, + const char *file, int line, const char *fmt, ...); + +// Pointer support flags +#define NCCL_PTR_HOST 0x1 +#define NCCL_PTR_CUDA 0x2 +#define NCCL_PTR_DMABUF 0x4 + +// Maximum handle size +#define NCCL_NET_HANDLE_MAXSIZE 128 + +// Net device types +#define NCCL_NET_DEVICE_HOST 0 +#define NCCL_NET_DEVICE_INVALID_VERSION 0 + +// Maximum sizes +#define NCCL_MAX_NET_SIZE_BYTES (1ULL << 31) + +#endif // NCCL_ERR_H diff --git a/nccl/net.h b/nccl/net.h new file mode 100644 index 0000000..9be245c --- /dev/null +++ b/nccl/net.h @@ -0,0 +1,18 @@ +/* + * NCCL Net Plugin API - main header + */ + +#ifndef NCCL_NET_H +#define NCCL_NET_H + +#include "err.h" +#include "net_v8.h" + +// Maximum number of outstanding requests +#define NCCL_NET_MAX_REQUESTS 32 + +// Use v8 as current version +typedef ncclNet_v8_t ncclNet_t; +typedef ncclNetProperties_v8_t ncclNetProperties_t; + +#endif // NCCL_NET_H diff --git a/nccl/net_v8.h b/nccl/net_v8.h new file mode 100644 index 0000000..b06f337 --- /dev/null +++ b/nccl/net_v8.h @@ -0,0 +1,101 @@ +/* + * NCCL Net Plugin API v8 - extracted from NCCL headers + */ + +#ifndef 
NCCL_NET_V8_H +#define NCCL_NET_V8_H + +#include "err.h" +#include +#include + +// Network device handle (opaque to NCCL) +typedef void* ncclNetDeviceHandle_t; + +// Network properties structure (v8) +typedef struct { + char* name; // Used mostly for logging + char* pciPath; // Path to the PCI device in /sys + uint64_t guid; // Unique identifier for the NIC chip + int ptrSupport; // NCCL_PTR_HOST or NCCL_PTR_HOST|NCCL_PTR_CUDA + int speed; // Port speed in Mbps + int port; // Port number + float latency; // Network latency in microseconds + int maxComms; // Maximum number of comms we can create + int maxRecvs; // Maximum number of grouped receives + int netDeviceType; // Network device type + int netDeviceVersion; // Network device version + uint64_t maxP2pBytes; // Maximum P2P transfer size +} ncclNetProperties_v8_t; + +// Net plugin structure v8 +typedef struct { + // Name of the network (mainly for logs) + const char* name; + + // Initialize the network + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + + // Return the number of adapters + ncclResult_t (*devices)(int* ndev); + + // Get various device properties + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props); + + // Create a receiving object and provide a handle to connect to it + // The handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be + // exchanged between ranks to create a connection + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + + // Connect to a handle and return a sending comm object for that peer + // This call must not block for the connection to be established, and + // instead should return ncclSuccess with sendComm == NULL if the + // connection is not established yet + ncclResult_t (*connect)(int dev, void* handle, void** sendComm, + ncclNetDeviceHandle_t** sendDevComm); + + // Finalize connection establishment after remote peer has called connect + // This call must not block for the connection to be established, and + // instead 
should return ncclSuccess with recvComm == NULL if the + // connection is not established yet + ncclResult_t (*accept)(void* listenComm, void** recvComm, + ncclNetDeviceHandle_t** recvDevComm); + + // Register/deregister memory for use with send/recv + ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, + void** mhandle); + ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, + uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* comm, void* mhandle); + + // Asynchronous send to a peer + // May return ncclInProgress if the operation cannot be posted immediately + ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, + void* mhandle, void** request); + + // Asynchronous receive from a peer + // May return ncclInProgress if the operation cannot be posted immediately + ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, + int* tags, void** mhandles, void** request); + + // Flush data received through irecv + ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, + void** mhandles, void** request); + + // Test whether a request has completed + ncclResult_t (*test)(void* request, int* done, int* sizes); + + // Close and free send/recv comm objects + ncclResult_t (*closeSend)(void* sendComm); + ncclResult_t (*closeRecv)(void* recvComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Get device-side memory handle for registered memory + ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle); + + // Notify that irecv has been consumed + ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); + +} ncclNet_v8_t; + +#endif // NCCL_NET_V8_H diff --git a/src/mesh_plugin.c b/src/mesh_plugin.c new file mode 100644 index 0000000..275bcd4 --- /dev/null +++ b/src/mesh_plugin.c @@ -0,0 +1,1508 @@ +/* + * NCCL Mesh Plugin - Main Implementation + * + * Subnet-aware RDMA transport for direct-connect mesh topologies + */ + +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "nccl/net.h" +#include "mesh_plugin.h" + +// Global state +struct mesh_plugin_state g_mesh_state = {0}; + +// Plugin name +#define PLUGIN_NAME "Mesh" + +/* + * Utility: Convert IP string to uint32 + */ +uint32_t mesh_ip_to_uint(const char *ip_str) { + struct in_addr addr; + if (inet_pton(AF_INET, ip_str, &addr) != 1) { + return 0; + } + return ntohl(addr.s_addr); +} + +/* + * Utility: Convert uint32 to IP string + */ +void mesh_uint_to_ip(uint32_t ip, char *buf, size_t len) { + struct in_addr addr; + addr.s_addr = htonl(ip); + inet_ntop(AF_INET, &addr, buf, len); +} + +/* + * Get IP address and netmask for a network interface + */ +int mesh_get_interface_ip(const char *if_name, uint32_t *ip, uint32_t *mask) { + struct ifaddrs *ifaddr, *ifa; + int found = 0; + + if (getifaddrs(&ifaddr) == -1) { + return -1; + } + + for (ifa = ifaddr; ifa != NULL; ifa = ifa->ifa_next) { + if (ifa->ifa_addr == NULL) continue; + if (ifa->ifa_addr->sa_family != AF_INET) continue; + if (strcmp(ifa->ifa_name, if_name) != 0) continue; + + struct sockaddr_in *addr = (struct sockaddr_in *)ifa->ifa_addr; + struct sockaddr_in *netmask = (struct sockaddr_in *)ifa->ifa_netmask; + + *ip = ntohl(addr->sin_addr.s_addr); + *mask = ntohl(netmask->sin_addr.s_addr); + found = 1; + break; + } + + freeifaddrs(ifaddr); + return found ? 
0 : -1; +} + +/* + * Find network interface name for an RDMA device + * Looks in /sys/class/infiniband//device/net/ + */ +const char* mesh_find_netdev_for_rdma(const char *rdma_dev) { + static char netdev[64]; + char path[256]; + DIR *dir; + struct dirent *entry; + + snprintf(path, sizeof(path), "/sys/class/infiniband/%s/device/net", rdma_dev); + dir = opendir(path); + if (!dir) { + return NULL; + } + + while ((entry = readdir(dir)) != NULL) { + if (entry->d_name[0] != '.') { + strncpy(netdev, entry->d_name, sizeof(netdev) - 1); + netdev[sizeof(netdev) - 1] = '\0'; + closedir(dir); + return netdev; + } + } + + closedir(dir); + return NULL; +} + +/* + * Find the NIC that can reach a given IP address (same subnet) + */ +struct mesh_nic* mesh_find_nic_for_ip(uint32_t peer_ip) { + for (int i = 0; i < g_mesh_state.num_nics; i++) { + struct mesh_nic *nic = &g_mesh_state.nics[i]; + uint32_t peer_subnet = peer_ip & nic->netmask; + + MESH_DEBUG("Checking NIC %s: peer_ip=0x%x, subnet=0x%x, nic_subnet=0x%x", + nic->dev_name, peer_ip, peer_subnet, nic->subnet); + + if (peer_subnet == nic->subnet) { + MESH_DEBUG("Found matching NIC %s for peer IP 0x%x", + nic->dev_name, peer_ip); + return nic; + } + } + + MESH_WARN("No NIC found for peer IP 0x%x", peer_ip); + return NULL; +} + +/* + * Setup a single NIC + */ +int mesh_setup_nic(struct mesh_nic *nic, struct ibv_device *device) { + struct ibv_device_attr dev_attr; + struct ibv_port_attr port_attr; + + // Open context + nic->context = ibv_open_device(device); + if (!nic->context) { + MESH_WARN("Failed to open device %s", ibv_get_device_name(device)); + return -1; + } + + // Get device name + strncpy(nic->dev_name, ibv_get_device_name(device), sizeof(nic->dev_name) - 1); + + // Find associated network interface + const char *netdev = mesh_find_netdev_for_rdma(nic->dev_name); + if (netdev) { + strncpy(nic->if_name, netdev, sizeof(nic->if_name) - 1); + + // Get IP address + if (mesh_get_interface_ip(nic->if_name, &nic->ip_addr, 
&nic->netmask) == 0) { + nic->subnet = nic->ip_addr & nic->netmask; + + char ip_str[INET_ADDRSTRLEN], mask_str[INET_ADDRSTRLEN], subnet_str[INET_ADDRSTRLEN]; + mesh_uint_to_ip(nic->ip_addr, ip_str, sizeof(ip_str)); + mesh_uint_to_ip(nic->netmask, mask_str, sizeof(mask_str)); + mesh_uint_to_ip(nic->subnet, subnet_str, sizeof(subnet_str)); + + MESH_INFO("NIC %s (%s): IP=%s, mask=%s, subnet=%s", + nic->dev_name, nic->if_name, ip_str, mask_str, subnet_str); + } else { + MESH_WARN("Could not get IP for interface %s", nic->if_name); + } + } else { + MESH_WARN("Could not find netdev for RDMA device %s", nic->dev_name); + } + + // Query device attributes + if (ibv_query_device(nic->context, &dev_attr)) { + MESH_WARN("Failed to query device %s", nic->dev_name); + ibv_close_device(nic->context); + return -1; + } + + nic->max_qp = dev_attr.max_qp; + nic->max_cq = dev_attr.max_cq; + nic->max_mr = dev_attr.max_mr; + nic->max_sge = dev_attr.max_sge; + nic->max_mr_size = dev_attr.max_mr_size; + + // Query port (assume port 1) + nic->port_num = 1; + if (ibv_query_port(nic->context, nic->port_num, &port_attr)) { + MESH_WARN("Failed to query port for %s", nic->dev_name); + ibv_close_device(nic->context); + return -1; + } + + // Allocate protection domain + nic->pd = ibv_alloc_pd(nic->context); + if (!nic->pd) { + MESH_WARN("Failed to allocate PD for %s", nic->dev_name); + ibv_close_device(nic->context); + return -1; + } + + // Use configured GID index or default to 3 (RoCE v2 with IPv4) + nic->gid_index = g_mesh_state.gid_index; + + MESH_INFO("Initialized NIC %s: max_qp=%d, max_mr=%d, gid_index=%d", + nic->dev_name, nic->max_qp, nic->max_mr, nic->gid_index); + + return 0; +} + +/* + * Initialize all NICs + */ +int mesh_init_nics(void) { + struct ibv_device **dev_list; + int num_devices; + + dev_list = ibv_get_device_list(&num_devices); + if (!dev_list) { + MESH_WARN("Failed to get RDMA device list"); + return -1; + } + + if (num_devices == 0) { + MESH_WARN("No RDMA devices found"); 
+ ibv_free_device_list(dev_list); + return -1; + } + + MESH_INFO("Found %d RDMA devices", num_devices); + + g_mesh_state.num_nics = 0; + for (int i = 0; i < num_devices && g_mesh_state.num_nics < MESH_MAX_NICS; i++) { + struct mesh_nic *nic = &g_mesh_state.nics[g_mesh_state.num_nics]; + memset(nic, 0, sizeof(*nic)); + + if (mesh_setup_nic(nic, dev_list[i]) == 0) { + // Only count NICs that have an IP configured + if (nic->ip_addr != 0) { + g_mesh_state.num_nics++; + } else { + // Clean up NIC without IP + if (nic->pd) ibv_dealloc_pd(nic->pd); + if (nic->context) ibv_close_device(nic->context); + } + } + } + + ibv_free_device_list(dev_list); + + MESH_INFO("Initialized %d NICs with IP addresses", g_mesh_state.num_nics); + return g_mesh_state.num_nics > 0 ? 0 : -1; +} + +/* + * Create a listening socket for QP handshake + */ +int mesh_create_handshake_socket(uint32_t bind_ip, uint16_t *port_out) { + int sock; + struct sockaddr_in addr; + socklen_t addrlen = sizeof(addr); + int opt = 1; + + sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + MESH_WARN("Failed to create handshake socket: %s", strerror(errno)); + return -1; + } + + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(bind_ip); + addr.sin_port = 0; // Let OS choose port + + if (bind(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + MESH_WARN("Failed to bind handshake socket: %s", strerror(errno)); + close(sock); + return -1; + } + + if (listen(sock, 16) < 0) { + MESH_WARN("Failed to listen on handshake socket: %s", strerror(errno)); + close(sock); + return -1; + } + + // Get assigned port + if (getsockname(sock, (struct sockaddr *)&addr, &addrlen) < 0) { + MESH_WARN("Failed to get socket name: %s", strerror(errno)); + close(sock); + return -1; + } + + *port_out = ntohs(addr.sin_port); + + char ip_str[INET_ADDRSTRLEN]; + mesh_uint_to_ip(bind_ip, ip_str, sizeof(ip_str)); + 
MESH_INFO("Handshake socket listening on %s:%d", ip_str, *port_out); + + return sock; +} + +/* + * Accept a handshake connection, receive remote QP info, and send ours back + */ +int mesh_accept_handshake(int listen_sock, struct mesh_qp_info *remote_info, struct mesh_qp_info *local_info) { + int conn_sock; + struct sockaddr_in addr; + socklen_t addrlen = sizeof(addr); + + conn_sock = accept(listen_sock, (struct sockaddr *)&addr, &addrlen); + if (conn_sock < 0) { + MESH_WARN("Failed to accept handshake connection: %s", strerror(errno)); + return -1; + } + + char ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &addr.sin_addr, ip_str, sizeof(ip_str)); + MESH_INFO("Accepted handshake connection from %s:%d", ip_str, ntohs(addr.sin_port)); + + // Receive remote QP info + ssize_t n = recv(conn_sock, remote_info, sizeof(*remote_info), MSG_WAITALL); + if (n != sizeof(*remote_info)) { + MESH_WARN("Failed to receive QP info: got %zd bytes, expected %zu", n, sizeof(*remote_info)); + close(conn_sock); + return -1; + } + + MESH_INFO("Received remote QP info: qp_num=%u, psn=%u", + ntohl(remote_info->qp_num), ntohl(remote_info->psn)); + + // Send our QP info back + n = send(conn_sock, local_info, sizeof(*local_info), 0); + if (n != sizeof(*local_info)) { + MESH_WARN("Failed to send local QP info: sent %zd bytes, expected %zu", n, sizeof(*local_info)); + close(conn_sock); + return -1; + } + + MESH_INFO("Sent local QP info: qp_num=%u, psn=%u", + ntohl(local_info->qp_num), ntohl(local_info->psn)); + + close(conn_sock); + return 0; +} + +/* + * Connect, send our QP info, and receive remote's QP info + * Uses non-blocking connect with select() to avoid deadlock + */ +int mesh_send_handshake(uint32_t remote_ip, uint16_t remote_port, + struct mesh_qp_info *local_info, struct mesh_qp_info *remote_info) { + int sock; + struct sockaddr_in addr; + int retries = 100; // 10 seconds total + int connected = 0; + + char ip_str[INET_ADDRSTRLEN]; + mesh_uint_to_ip(remote_ip, ip_str, sizeof(ip_str)); 
+ fprintf(stderr, "MESH DEBUG: send_handshake connecting to %s:%d\n", ip_str, remote_port); + fflush(stderr); + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(remote_ip); + addr.sin_port = htons(remote_port); + + // Retry connection - peer's accept() might not be ready yet + while (retries > 0 && !connected) { + sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) { + MESH_WARN("Failed to create handshake socket: %s", strerror(errno)); + return -1; + } + + // Set non-blocking + int flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, flags | O_NONBLOCK); + + int ret = connect(sock, (struct sockaddr *)&addr, sizeof(addr)); + if (ret == 0) { + connected = 1; + } else if (errno == EINPROGRESS) { + // Wait for connection with select + fd_set writefds; + struct timeval tv; + FD_ZERO(&writefds); + FD_SET(sock, &writefds); + tv.tv_sec = 0; + tv.tv_usec = 100000; // 100ms + + ret = select(sock + 1, NULL, &writefds, NULL, &tv); + if (ret > 0) { + // Check if connection succeeded + int error = 0; + socklen_t len = sizeof(error); + getsockopt(sock, SOL_SOCKET, SO_ERROR, &error, &len); + if (error == 0) { + connected = 1; + } else { + close(sock); + retries--; + } + } else { + close(sock); + retries--; + } + } else { + close(sock); + retries--; + usleep(100000); // 100ms before retry + } + } + + if (!connected) { + MESH_WARN("Failed to connect handshake socket after retries"); + return -1; + } + + // Set back to blocking for send/recv + int flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, flags & ~O_NONBLOCK); + + fprintf(stderr, "MESH DEBUG: TCP handshake connected to %s:%d!\n", ip_str, remote_port); + fflush(stderr); + + // Send our QP info + ssize_t n = send(sock, local_info, sizeof(*local_info), 0); + if (n != sizeof(*local_info)) { + MESH_WARN("Failed to send QP info: sent %zd bytes, expected %zu", n, sizeof(*local_info)); + close(sock); + return -1; + } + + fprintf(stderr, "MESH DEBUG: Sent QP info, waiting for 
response...\n"); + fflush(stderr); + + // Receive remote's QP info (the accept side's NEW QP) + n = recv(sock, remote_info, sizeof(*remote_info), MSG_WAITALL); + if (n != sizeof(*remote_info)) { + MESH_WARN("Failed to receive remote QP info: got %zd bytes, expected %zu", n, sizeof(*remote_info)); + close(sock); + return -1; + } + + fprintf(stderr, "MESH DEBUG: Received remote QP info: qp_num=%u\n", ntohl(remote_info->qp_num)); + fflush(stderr); + + close(sock); + return 0; +} + +/* + * Background handshake thread + * Handles incoming TCP connections, creates QPs, and queues for accept() + */ +static void *handshake_thread_func(void *arg) { + struct mesh_listen_comm *lcomm = (struct mesh_listen_comm *)arg; + + fprintf(stderr, "MESH DEBUG: Handshake thread started, sock=%d\n", lcomm->handshake_sock); + fflush(stderr); + + // Set socket to non-blocking so we can check stop flag + int flags = fcntl(lcomm->handshake_sock, F_GETFL, 0); + fcntl(lcomm->handshake_sock, F_SETFL, flags | O_NONBLOCK); + + while (!lcomm->thread_stop) { + struct sockaddr_in addr; + socklen_t addrlen = sizeof(addr); + + int conn_sock = accept(lcomm->handshake_sock, (struct sockaddr *)&addr, &addrlen); + if (conn_sock < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + usleep(10000); // 10ms + continue; + } + break; + } + + char ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &addr.sin_addr, ip_str, sizeof(ip_str)); + fprintf(stderr, "MESH DEBUG: Handshake thread: connection from %s\n", ip_str); + fflush(stderr); + + // Receive remote QP info + struct mesh_qp_info remote_info; + ssize_t n = recv(conn_sock, &remote_info, sizeof(remote_info), MSG_WAITALL); + if (n != sizeof(remote_info)) { + fprintf(stderr, "MESH DEBUG: Handshake thread: failed to recv, got %zd\n", n); + close(conn_sock); + continue; + } + + fprintf(stderr, "MESH DEBUG: Handshake thread: received QP %u, nic_idx=%d\n", + ntohl(remote_info.qp_num), remote_info.nic_idx); + fflush(stderr); + + // Select NIC based on nic_idx from 
remote + int nic_idx = remote_info.nic_idx; + if (nic_idx >= lcomm->num_qps) nic_idx = 0; + struct mesh_nic *nic = lcomm->qps[nic_idx].nic; + + // Create new QP for this connection + struct ibv_qp *new_qp = NULL; + struct ibv_cq *new_cq = NULL; + if (mesh_create_qp(nic, &new_qp, &new_cq) != 0) { + fprintf(stderr, "MESH DEBUG: Handshake thread: failed to create QP\n"); + close(conn_sock); + continue; + } + + fprintf(stderr, "MESH DEBUG: Handshake thread: created QP %d on %s\n", + new_qp->qp_num, nic->dev_name); + fflush(stderr); + + // Connect our QP to remote's QP + struct mesh_handle connect_handle; + memset(&connect_handle, 0, sizeof(connect_handle)); + connect_handle.qp_num = ntohl(remote_info.qp_num); + connect_handle.psn = ntohl(remote_info.psn); + connect_handle.port_num = nic->port_num; + connect_handle.mtu = IBV_MTU_4096; + + // Construct GID from remote IP + union ibv_gid remote_gid; + memset(&remote_gid, 0, sizeof(remote_gid)); + remote_gid.raw[10] = 0xff; + remote_gid.raw[11] = 0xff; + uint32_t remote_ip = remote_info.ip; + memcpy(&remote_gid.raw[12], &remote_ip, 4); + connect_handle.gid = remote_gid; + + if (mesh_connect_qp(new_qp, nic, &connect_handle) != 0) { + fprintf(stderr, "MESH DEBUG: Handshake thread: failed to connect QP\n"); + ibv_destroy_qp(new_qp); + ibv_destroy_cq(new_cq); + close(conn_sock); + continue; + } + + fprintf(stderr, "MESH DEBUG: Handshake thread: QP connected to remote QP %d\n", + connect_handle.qp_num); + fflush(stderr); + + // Send our QP info back + struct mesh_qp_info local_info; + memset(&local_info, 0, sizeof(local_info)); + local_info.qp_num = htonl(new_qp->qp_num); + local_info.psn = htonl(0); + local_info.ip = htonl(nic->ip_addr); + local_info.nic_idx = nic_idx; + + n = send(conn_sock, &local_info, sizeof(local_info), 0); + close(conn_sock); + + if (n != sizeof(local_info)) { + fprintf(stderr, "MESH DEBUG: Handshake thread: failed to send response\n"); + ibv_destroy_qp(new_qp); + ibv_destroy_cq(new_cq); + continue; + } 
+ + fprintf(stderr, "MESH DEBUG: Handshake thread: sent QP %d back, queueing for accept\n", + new_qp->qp_num); + fflush(stderr); + + // Queue this handshake for accept() to consume + pthread_mutex_lock(&lcomm->queue_mutex); + int next_tail = (lcomm->queue_tail + 1) % HANDSHAKE_QUEUE_SIZE; + if (next_tail != lcomm->queue_head) { + struct handshake_entry *entry = &lcomm->handshake_queue[lcomm->queue_tail]; + entry->remote_info = remote_info; + entry->local_qp = new_qp; + entry->local_cq = new_cq; + entry->nic = nic; + entry->valid = 1; + lcomm->queue_tail = next_tail; + pthread_cond_signal(&lcomm->queue_cond); + } else { + fprintf(stderr, "MESH DEBUG: Handshake thread: queue full!\n"); + ibv_destroy_qp(new_qp); + ibv_destroy_cq(new_cq); + } + pthread_mutex_unlock(&lcomm->queue_mutex); + } + + fprintf(stderr, "MESH DEBUG: Handshake thread exiting\n"); + fflush(stderr); + return NULL; +} + +/* + * Create QP and CQ on a NIC + */ +int mesh_create_qp(struct mesh_nic *nic, struct ibv_qp **qp_out, struct ibv_cq **cq_out) { + struct ibv_cq *cq; + struct ibv_qp *qp; + struct ibv_qp_init_attr qp_init_attr; + + // Create completion queue + cq = ibv_create_cq(nic->context, 128, NULL, NULL, 0); + if (!cq) { + MESH_WARN("Failed to create CQ on %s", nic->dev_name); + return -1; + } + + // Create queue pair + memset(&qp_init_attr, 0, sizeof(qp_init_attr)); + qp_init_attr.send_cq = cq; + qp_init_attr.recv_cq = cq; + qp_init_attr.qp_type = IBV_QPT_RC; + qp_init_attr.cap.max_send_wr = 64; + qp_init_attr.cap.max_recv_wr = 64; + qp_init_attr.cap.max_send_sge = 1; + qp_init_attr.cap.max_recv_sge = 1; + qp_init_attr.cap.max_inline_data = 64; + + qp = ibv_create_qp(nic->pd, &qp_init_attr); + if (!qp) { + MESH_WARN("Failed to create QP on %s", nic->dev_name); + ibv_destroy_cq(cq); + return -1; + } + + // Transition QP to INIT state + struct ibv_qp_attr qp_attr; + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.qp_state = IBV_QPS_INIT; + qp_attr.pkey_index = 0; + qp_attr.port_num = 
nic->port_num; + qp_attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; + + if (ibv_modify_qp(qp, &qp_attr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)) { + MESH_WARN("Failed to transition QP to INIT on %s", nic->dev_name); + ibv_destroy_qp(qp); + ibv_destroy_cq(cq); + return -1; + } + + *qp_out = qp; + *cq_out = cq; + return 0; +} + +/* + * Connect QP to remote peer + */ +int mesh_connect_qp(struct ibv_qp *qp, struct mesh_nic *nic, struct mesh_handle *remote) { + struct ibv_qp_attr qp_attr; + + // Transition to RTR + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.qp_state = IBV_QPS_RTR; + qp_attr.path_mtu = IBV_MTU_4096; + qp_attr.dest_qp_num = remote->qp_num; + qp_attr.rq_psn = remote->psn; + qp_attr.max_dest_rd_atomic = 1; + qp_attr.min_rnr_timer = 12; + qp_attr.ah_attr.is_global = 1; + qp_attr.ah_attr.grh.dgid = remote->gid; + qp_attr.ah_attr.grh.sgid_index = nic->gid_index; + qp_attr.ah_attr.grh.hop_limit = 64; + qp_attr.ah_attr.dlid = remote->lid; + qp_attr.ah_attr.sl = 0; + qp_attr.ah_attr.src_path_bits = 0; + qp_attr.ah_attr.port_num = nic->port_num; + + if (ibv_modify_qp(qp, &qp_attr, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER)) { + MESH_WARN("Failed to transition QP to RTR"); + return -1; + } + + // Transition to RTS + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.qp_state = IBV_QPS_RTS; + qp_attr.timeout = 14; + qp_attr.retry_cnt = 7; + qp_attr.rnr_retry = 7; + qp_attr.sq_psn = 0; + qp_attr.max_rd_atomic = 1; + + if (ibv_modify_qp(qp, &qp_attr, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC)) { + MESH_WARN("Failed to transition QP to RTS"); + return -1; + } + + return 0; +} + +/* + * ============================================================================ + * NCCL Plugin API Implementation + * 
============================================================================ + */ + +static ncclResult_t mesh_init(ncclDebugLogger_t logFunction) { + if (g_mesh_state.initialized) { + return ncclSuccess; + } + + g_mesh_state.log_fn = logFunction; + + // Read configuration from environment + const char *gid_str = getenv("NCCL_MESH_GID_INDEX"); + g_mesh_state.gid_index = gid_str ? atoi(gid_str) : 3; + + const char *debug_str = getenv("NCCL_MESH_DEBUG"); + g_mesh_state.debug = debug_str ? atoi(debug_str) : 0; + + MESH_INFO("Initializing Mesh plugin (gid_index=%d, debug=%d)", + g_mesh_state.gid_index, g_mesh_state.debug); + + if (mesh_init_nics() != 0) { + MESH_WARN("Failed to initialize NICs"); + return ncclSystemError; + } + + g_mesh_state.initialized = 1; + MESH_INFO("Mesh plugin initialized with %d NICs", g_mesh_state.num_nics); + + return ncclSuccess; +} + +static ncclResult_t mesh_devices(int *ndev) { + *ndev = g_mesh_state.num_nics; + return ncclSuccess; +} + +static ncclResult_t mesh_getProperties(int dev, ncclNetProperties_v8_t *props) { + if (dev < 0 || dev >= g_mesh_state.num_nics) { + return ncclInvalidArgument; + } + + struct mesh_nic *nic = &g_mesh_state.nics[dev]; + + memset(props, 0, sizeof(*props)); + props->name = nic->dev_name; + props->pciPath = nic->pci_path; + props->guid = 0; // TODO: Get actual GUID + props->ptrSupport = NCCL_PTR_HOST; // Only host memory for now (no GPUDirect RDMA) + props->speed = 100000; // 100 Gbps + props->port = nic->port_num; + props->latency = 1.0; + props->maxComms = nic->max_qp; + props->maxRecvs = 1; + props->netDeviceType = NCCL_NET_DEVICE_HOST; + props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->maxP2pBytes = NCCL_MAX_NET_SIZE_BYTES; + + return ncclSuccess; +} + +static ncclResult_t mesh_listen(int dev, void *handle, void **listenComm) { + (void)dev; // We listen on ALL NICs, not just the requested one + + struct mesh_handle *h = (struct mesh_handle *)handle; + struct mesh_listen_comm *comm; + 
    union ibv_gid gid;

    // Allocate listen comm
    comm = calloc(1, sizeof(*comm));
    if (!comm) {
        return ncclSystemError;
    }

    comm->num_qps = 0;
    comm->psn = 0;
    comm->handshake_sock = -1;

    // Create QP on EACH NIC
    for (int i = 0; i < g_mesh_state.num_nics && i < MESH_MAX_NICS; i++) {
        struct mesh_nic *nic = &g_mesh_state.nics[i];
        struct ibv_qp *qp = NULL;
        struct ibv_cq *cq = NULL;

        if (mesh_create_qp(nic, &qp, &cq) != 0) {
            MESH_WARN("Failed to create QP on NIC %s, skipping", nic->dev_name);
            continue;
        }

        comm->qps[comm->num_qps].nic = nic;
        comm->qps[comm->num_qps].qp = qp;
        comm->qps[comm->num_qps].cq = cq;
        comm->num_qps++;

        char ip_str[INET_ADDRSTRLEN];
        mesh_uint_to_ip(nic->ip_addr, ip_str, sizeof(ip_str));
        MESH_INFO("listen: Created QP %d on %s (IP=%s)", qp->qp_num, nic->dev_name, ip_str);
    }

    if (comm->num_qps == 0) {
        MESH_WARN("Failed to create any QPs");
        free(comm);
        return ncclSystemError;
    }

    // Create handshake socket - bind to all interfaces so any peer can reach us
    comm->handshake_ip = INADDR_ANY;
    comm->handshake_sock = mesh_create_handshake_socket(INADDR_ANY, &comm->handshake_port);
    if (comm->handshake_sock < 0) {
        MESH_WARN("Failed to create handshake socket");
        // Continue without handshake - will fail at accept
    }

    // Initialize handshake queue and thread
    pthread_mutex_init(&comm->queue_mutex, NULL);
    pthread_cond_init(&comm->queue_cond, NULL);
    comm->queue_head = 0;
    comm->queue_tail = 0;
    comm->thread_stop = 0;
    comm->thread_running = 0;

    // Start handshake thread
    if (comm->handshake_sock >= 0) {
        if (pthread_create(&comm->handshake_thread, NULL, handshake_thread_func, comm) == 0) {
            comm->thread_running = 1;
            fprintf(stderr, "MESH DEBUG: listen: Started handshake thread\n");
            fflush(stderr);
        } else {
            MESH_WARN("Failed to start handshake thread");
        }
    }

    // Fill handle with ALL our addresses
    memset(h, 0, sizeof(*h));
    h->magic = MESH_HANDLE_MAGIC;
    h->num_addrs = 0;
    h->psn = comm->psn;
    h->port_num = 1;
    h->mtu = IBV_MTU_4096;
    h->handshake_port = comm->handshake_port;
    // Store first NIC IP in handle - but connector will use selected_addr->ip for handshake
    h->handshake_ip = htonl(comm->qps[0].nic->ip_addr);

    // Get GID from first NIC for the primary GID field
    struct mesh_nic *primary_nic = comm->qps[0].nic;
    if (ibv_query_gid(primary_nic->context, primary_nic->port_num, primary_nic->gid_index, &gid) == 0) {
        h->gid = gid;
    }

    // Add all NIC addresses to the handle
    for (int i = 0; i < comm->num_qps && h->num_addrs < MESH_MAX_ADDRS; i++) {
        struct mesh_nic *nic = comm->qps[i].nic;
        struct mesh_addr_entry *entry = &h->addrs[h->num_addrs];

        entry->ip = htonl(nic->ip_addr);
        entry->mask = htonl(nic->netmask);
        entry->qp_num = comm->qps[i].qp->qp_num;
        entry->nic_idx = i;
        entry->gid_index = nic->gid_index;

        char ip_str[INET_ADDRSTRLEN];
        mesh_uint_to_ip(nic->ip_addr, ip_str, sizeof(ip_str));
        MESH_INFO("listen: Advertising address %d: %s (QP %d)",
                  h->num_addrs, ip_str, entry->qp_num);

        h->num_addrs++;
    }

    MESH_INFO("listen: Ready with %d addresses on %d QPs, handshake port %d",
              h->num_addrs, comm->num_qps, comm->handshake_port);

    *listenComm = comm;
    return ncclSuccess;
}

/*
 * connect() - THE KEY FUNCTION
 *
 * Search through peer's advertised addresses to find one on a subnet we can reach
 */
static ncclResult_t mesh_connect(int dev, void *opaqueHandle, void **sendComm,
                                 ncclNetDeviceHandle_t **sendDevComm) {
    (void)dev; // We pick the right NIC based on subnet match

    struct mesh_handle *handle = (struct mesh_handle *)opaqueHandle;
    struct mesh_send_comm *comm;
    struct mesh_nic *nic = NULL;
    struct mesh_addr_entry *selected_addr = NULL;

    // Validate handle
    if (handle->magic != MESH_HANDLE_MAGIC) {
        MESH_WARN("Invalid handle magic: 0x%x", handle->magic);
        return ncclInvalidArgument;
    }

    MESH_INFO("connect: Peer advertised %d addresses", handle->num_addrs);

    // Search through peer's addresses to find one we can reach
    for (int i = 0; i < handle->num_addrs; i++) {
        struct mesh_addr_entry *addr = &handle->addrs[i];
        uint32_t peer_ip = ntohl(addr->ip);

        char ip_str[INET_ADDRSTRLEN];
        mesh_uint_to_ip(peer_ip, ip_str, sizeof(ip_str));
        MESH_DEBUG("connect: Checking peer address %d: %s", i, ip_str);

        // Find local NIC on same subnet
        nic = mesh_find_nic_for_ip(peer_ip);
        if (nic) {
            selected_addr = addr;
            MESH_INFO("connect: Found matching NIC %s for peer %s", nic->dev_name, ip_str);
            break;
        }
    }

    if (!nic || !selected_addr) {
        MESH_WARN("connect: No local NIC found on same subnet as any peer address");
        for (int i = 0; i < handle->num_addrs; i++) {
            char ip_str[INET_ADDRSTRLEN];
            mesh_uint_to_ip(ntohl(handle->addrs[i].ip), ip_str, sizeof(ip_str));
            MESH_WARN(" Peer address %d: %s", i, ip_str);
        }
        return ncclSystemError;
    }

    char peer_ip_str[INET_ADDRSTRLEN];
    mesh_uint_to_ip(ntohl(selected_addr->ip), peer_ip_str, sizeof(peer_ip_str));

    // Allocate send comm
    comm = calloc(1, sizeof(*comm));
    if (!comm) {
        return ncclSystemError;
    }

    comm->nic = nic;

    // Create QP on the selected NIC
    if (mesh_create_qp(nic, &comm->qp, &comm->cq) != 0) {
        free(comm);
        return ncclSystemError;
    }

    fprintf(stderr, "MESH DEBUG: connect created QP %d on NIC %s\n", comm->qp->qp_num, nic->dev_name);
    fflush(stderr);

    // Do handshake FIRST to get accept's QP number
    struct mesh_qp_info remote_qp_info;
    memset(&remote_qp_info, 0, sizeof(remote_qp_info));

    if (handle->handshake_port > 0) {
        fprintf(stderr, "MESH DEBUG: connect doing bidirectional handshake\n");
        fflush(stderr);

        struct mesh_qp_info local_info;
        memset(&local_info, 0, sizeof(local_info));
        local_info.qp_num = htonl(comm->qp->qp_num);
        local_info.psn = htonl(0); // Our PSN
        local_info.ip = htonl(nic->ip_addr);
        local_info.gid_index = nic->gid_index;
        local_info.nic_idx =
            selected_addr->nic_idx; // Which of listener's NICs we want

        // Copy our GID
        union ibv_gid our_gid;
        if (ibv_query_gid(nic->context, nic->port_num, nic->gid_index, &our_gid) == 0) {
            memcpy(local_info.gid, our_gid.raw, 16);
        }

        // Bidirectional handshake - send our info, receive accept's info
        uint32_t handshake_ip = ntohl(selected_addr->ip);

        char hs_ip_str[INET_ADDRSTRLEN];
        mesh_uint_to_ip(handshake_ip, hs_ip_str, sizeof(hs_ip_str));
        fprintf(stderr, "MESH DEBUG: Sending handshake to %s:%d\n", hs_ip_str, handle->handshake_port);
        fflush(stderr);

        if (mesh_send_handshake(handshake_ip, handle->handshake_port, &local_info, &remote_qp_info) != 0) {
            MESH_WARN("connect: Bidirectional handshake failed");
            fprintf(stderr, "MESH DEBUG: Handshake FAILED\n");
            ibv_destroy_qp(comm->qp);
            ibv_destroy_cq(comm->cq);
            free(comm);
            return ncclSystemError;
        }

        fprintf(stderr, "MESH DEBUG: Handshake complete! Remote QP=%u\n", ntohl(remote_qp_info.qp_num));
        fflush(stderr);
    } else {
        // Legacy path: no handshake service advertised by the peer.
        MESH_WARN("connect: No handshake port - using listen QP (will likely fail)");
        remote_qp_info.qp_num = htonl(selected_addr->qp_num);
        remote_qp_info.psn = htonl(handle->psn);
        remote_qp_info.ip = selected_addr->ip;
    }

    // Now connect our QP to the ACCEPT's QP (from handshake response)
    struct mesh_handle connect_handle;
    memset(&connect_handle, 0, sizeof(connect_handle));
    connect_handle.qp_num = ntohl(remote_qp_info.qp_num); // Accept's new QP!
    connect_handle.psn = ntohl(remote_qp_info.psn);
    connect_handle.lid = 0; // RoCE uses GID, not LID
    connect_handle.port_num = nic->port_num;
    connect_handle.mtu = IBV_MTU_4096;

    // Construct peer GID from their IP (IPv4-mapped form ::ffff:a.b.c.d)
    union ibv_gid peer_gid;
    memset(&peer_gid, 0, sizeof(peer_gid));
    peer_gid.raw[10] = 0xff;
    peer_gid.raw[11] = 0xff;
    uint32_t remote_ip_for_gid = remote_qp_info.ip; // Already in network byte order from handshake
    if (remote_ip_for_gid == 0) {
        remote_ip_for_gid = selected_addr->ip; // Fallback
    }
    memcpy(&peer_gid.raw[12], &remote_ip_for_gid, 4);
    connect_handle.gid = peer_gid;

    fprintf(stderr, "MESH DEBUG: connect transitioning QP to connect to remote QP %d\n", connect_handle.qp_num);
    fflush(stderr);

    // Connect QP to remote
    if (mesh_connect_qp(comm->qp, nic, &connect_handle) != 0) {
        MESH_WARN("connect: Failed to connect QP to peer");
        ibv_destroy_qp(comm->qp);
        ibv_destroy_cq(comm->cq);
        free(comm);
        return ncclSystemError;
    }

    comm->connected = 1;
    comm->remote_qp_num = connect_handle.qp_num;

    MESH_INFO("connect: Connected to peer %s via NIC %s (local QP %d -> remote QP %d)",
              peer_ip_str, nic->dev_name, comm->qp->qp_num, connect_handle.qp_num);

    fprintf(stderr, "MESH DEBUG: connect returning SUCCESS, comm=%p\n", (void*)comm);
    fflush(stderr);

    *sendComm = comm;
    if (sendDevComm) *sendDevComm = NULL;
    return ncclSuccess;
}

/*
 * accept() - take one completed handshake off the queue filled by the
 * background handshake thread (which presumably creates and connects the
 * QP stored in the entry - see handshake_thread_func) and wrap it in a
 * recv comm. Waits up to 30 seconds for an entry to arrive.
 */
static ncclResult_t mesh_accept(void *listenComm, void **recvComm,
                                ncclNetDeviceHandle_t **recvDevComm) {
    struct mesh_listen_comm *lcomm = (struct mesh_listen_comm *)listenComm;
    struct mesh_recv_comm *rcomm;

    fprintf(stderr, "MESH DEBUG: mesh_accept called, thread_running=%d\n", lcomm->thread_running);
    fflush(stderr);

    // Allocate recv comm
    rcomm = calloc(1, sizeof(*rcomm));
    if (!rcomm) {
        return ncclSystemError;
    }

    // Wait for handshake from queue (filled by handshake thread)
    pthread_mutex_lock(&lcomm->queue_mutex);

    // Wait with timeout for entry in queue
    struct timespec timeout;
    clock_gettime(CLOCK_REALTIME, &timeout);
    timeout.tv_sec += 30; // 30 second timeout

    while (lcomm->queue_head == lcomm->queue_tail) {
        fprintf(stderr, "MESH DEBUG: accept: waiting for handshake in queue...\n");
        fflush(stderr);
        int rc = pthread_cond_timedwait(&lcomm->queue_cond, &lcomm->queue_mutex, &timeout);
        if (rc == ETIMEDOUT) {
            pthread_mutex_unlock(&lcomm->queue_mutex);
            MESH_WARN("accept: Timeout waiting for handshake");
            free(rcomm);
            return ncclSystemError;
        }
    }

    // Get entry from queue
    struct handshake_entry *entry = &lcomm->handshake_queue[lcomm->queue_head];
    lcomm->queue_head = (lcomm->queue_head + 1) % HANDSHAKE_QUEUE_SIZE;

    // Copy data out; ownership of QP/CQ transfers to the recv comm
    rcomm->qp = entry->local_qp;
    rcomm->cq = entry->local_cq;
    rcomm->nic = entry->nic;
    entry->valid = 0;

    pthread_mutex_unlock(&lcomm->queue_mutex);

    fprintf(stderr, "MESH DEBUG: accept: Got handshake from queue - QP=%d\n", rcomm->qp->qp_num);
    fflush(stderr);

    rcomm->connected = 1;

    MESH_INFO("accept: Ready on %s (QP %d)", rcomm->nic->dev_name, rcomm->qp->qp_num);

    *recvComm = rcomm;
    if (recvDevComm) *recvDevComm = NULL;

    fprintf(stderr, "MESH DEBUG: accept returning SUCCESS\n");
    fflush(stderr);

    return ncclSuccess;
}

// regMr() - register a host-memory region against the comm's NIC PD.
static ncclResult_t mesh_regMr(void *comm, void *data, size_t size, int type, void **mhandle) {
    fprintf(stderr, "MESH DEBUG: regMr ENTRY comm=%p, data=%p, size=%zu, type=%d\n",
            comm, data, size, type);
    fflush(stderr);

    struct mesh_send_comm *scomm = (struct mesh_send_comm *)comm;
    struct mesh_mr_handle *mrh;
    int access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ;

    fprintf(stderr, "MESH DEBUG: regMr after cast, scomm=%p\n", (void*)scomm);
    fflush(stderr);

    if (!scomm || !scomm->nic || !scomm->nic->pd) {
        MESH_WARN("regMr: invalid comm or nic");
        fprintf(stderr, "MESH DEBUG: regMr invalid - scomm=%p\n", (void*)scomm);
        if (scomm)
            fprintf(stderr, "MESH DEBUG: scomm->nic=%p\n", (void*)scomm->nic);
        return ncclSystemError;
    }

    fprintf(stderr, "MESH DEBUG: regMr nic=%s, pd=%p\n", scomm->nic->dev_name, (void*)scomm->nic->pd);
    fflush(stderr);

    mrh = calloc(1, sizeof(*mrh));
    if (!mrh) {
        return ncclSystemError;
    }

    mrh->mr = ibv_reg_mr(scomm->nic->pd, data, size, access_flags);
    if (!mrh->mr) {
        MESH_WARN("Failed to register MR: %s", strerror(errno));
        fprintf(stderr, "MESH DEBUG: ibv_reg_mr failed: %s\n", strerror(errno));
        free(mrh);
        return ncclSystemError;
    }

    fprintf(stderr, "MESH DEBUG: regMr success, mr=%p, lkey=%u\n", (void*)mrh->mr, mrh->mr->lkey);
    fprintf(stderr, "MESH DEBUG: regMr returning mhandle=%p (mrh->mr=%p)\n", (void*)mrh, (void*)mrh->mr);
    fflush(stderr);

    mrh->nic = scomm->nic;
    mrh->addr = data;
    mrh->size = size;

    *mhandle = mrh;
    return ncclSuccess;
}

// regMrDmaBuf() - DMA-BUF path; falls back to plain regMr (offset/fd unused).
static ncclResult_t mesh_regMrDmaBuf(void *comm, void *data, size_t size, int type,
                                     uint64_t offset, int fd, void **mhandle) {
    // DMA-BUF not implemented yet
    return mesh_regMr(comm, data, size, type, mhandle);
}

// deregMr() - deregister the MR and free the wrapper allocated by regMr.
static ncclResult_t mesh_deregMr(void *comm, void *mhandle) {
    struct mesh_mr_handle *mrh = (struct mesh_mr_handle *)mhandle;

    if (mrh && mrh->mr) {
        ibv_dereg_mr(mrh->mr);
    }
    free(mrh);

    return ncclSuccess;
}

/*
 * isend() - post one SEND work request; completion is reported via test().
 */
static ncclResult_t mesh_isend(void *sendComm, void *data, int size, int tag,
                               void *mhandle, void **request) {
    struct mesh_send_comm *comm = (struct mesh_send_comm *)sendComm;
    struct mesh_mr_handle *mrh = (struct mesh_mr_handle *)mhandle;
    struct mesh_request *req;
    struct ibv_send_wr wr, *bad_wr;
    struct ibv_sge sge;

    (void)tag;

    fprintf(stderr, "MESH DEBUG: isend called, comm=%p, data=%p, size=%d, mhandle=%p\n",
            (void*)comm, data, size, (void*)mhandle);
    if (comm) fprintf(stderr, "MESH DEBUG: isend comm->qp=%p, comm->cq=%p\n", (void*)comm->qp, (void*)comm->cq);
    if (mrh) fprintf(stderr, "MESH DEBUG: isend mrh->mr=%p\n",
            (void*)mrh->mr);
    fflush(stderr);

    if (!comm || !comm->qp) {
        MESH_WARN("isend: invalid comm");
        return ncclSystemError;
    }
    if (!mrh || !mrh->mr) {
        MESH_WARN("isend: invalid mhandle");
        return ncclSystemError;
    }

    req = calloc(1, sizeof(*req));
    if (!req) {
        return ncclSystemError;
    }

    req->used = 1;
    req->size = size;
    req->cq = comm->cq; // Store CQ for polling
    req->done = 0;

    // Setup scatter/gather entry
    sge.addr = (uintptr_t)data;
    sge.length = size;
    sge.lkey = mrh->mr->lkey;

    fprintf(stderr, "MESH DEBUG: isend sge setup, checking PDs\n");
    fprintf(stderr, "MESH DEBUG: isend comm->nic->pd=%p, mrh->nic->pd=%p\n",
            (void*)(comm->nic ? comm->nic->pd : NULL),
            (void*)(mrh->nic ? mrh->nic->pd : NULL));
    if (comm->nic && mrh->nic && comm->nic->pd != mrh->nic->pd) {
        fprintf(stderr, "MESH DEBUG: ERROR - isend PD MISMATCH!\n");
    }
    fflush(stderr);

    // Setup send work request
    memset(&wr, 0, sizeof(wr));
    wr.wr_id = (uintptr_t)req;   // wr_id carries the request pointer back in the completion
    wr.next = NULL;
    wr.sg_list = &sge;
    wr.num_sge = 1;
    wr.opcode = IBV_WR_SEND;
    wr.send_flags = IBV_SEND_SIGNALED;

    fprintf(stderr, "MESH DEBUG: isend about to call ibv_post_send\n");
    fflush(stderr);

    if (ibv_post_send(comm->qp, &wr, &bad_wr)) {
        MESH_WARN("Failed to post send: %s", strerror(errno));
        fprintf(stderr, "MESH DEBUG: ibv_post_send FAILED: %s\n", strerror(errno));
        free(req);
        return ncclSystemError;
    }

    fprintf(stderr, "MESH DEBUG: ibv_post_send succeeded!\n");
    fflush(stderr);

    *request = req;
    return ncclSuccess;
}

/*
 * irecv() - post one RECV work request. Only n == 1 is supported.
 */
static ncclResult_t mesh_irecv(void *recvComm, int n, void **data, int *sizes,
                               int *tags, void **mhandles, void **request) {
    struct mesh_recv_comm *comm = (struct mesh_recv_comm *)recvComm;
    struct mesh_request *req;
    struct ibv_recv_wr wr, *bad_wr;
    struct ibv_sge sge;

    (void)tags;

    fprintf(stderr, "MESH DEBUG: irecv called, comm=%p, n=%d\n", (void*)comm, n);
    if (comm) fprintf(stderr, "MESH DEBUG: irecv comm->qp=%p, comm->cq=%p\n", (void*)comm->qp, (void*)comm->cq);
    fflush(stderr);

    if (!comm || !comm->qp) {
        MESH_WARN("irecv: invalid comm");
        return ncclSystemError;
    }

    if (n != 1) {
        // For simplicity, only handle n=1 for now
        MESH_WARN("irecv with n=%d not supported yet", n);
        return ncclInternalError;
    }

    struct mesh_mr_handle *mrh = (struct mesh_mr_handle *)mhandles[0];
    fprintf(stderr, "MESH DEBUG: irecv mrh=%p, data[0]=%p, sizes[0]=%d\n", (void*)mrh, data[0], sizes[0]);

    // Check if data address looks like GPU memory (high address space)
    uintptr_t data_addr = (uintptr_t)data[0];
    if (data_addr > 0x100000000ULL && data_addr < 0x800000000000ULL) {
        fprintf(stderr, "MESH DEBUG: WARNING - data address %p looks like GPU memory!\n", data[0]);
    }

    if (mrh) fprintf(stderr, "MESH DEBUG: irecv mrh->mr=%p\n", (void*)mrh->mr);
    fflush(stderr);

    if (!mrh || !mrh->mr) {
        MESH_WARN("irecv: invalid mhandle");
        return ncclSystemError;
    }

    fprintf(stderr, "MESH DEBUG: irecv about to access lkey\n");
    fflush(stderr);
    uint32_t lkey = mrh->mr->lkey;
    fprintf(stderr, "MESH DEBUG: irecv lkey=%u\n", lkey);
    fflush(stderr);

    req = calloc(1, sizeof(*req));
    if (!req) {
        return ncclSystemError;
    }

    fprintf(stderr, "MESH DEBUG: irecv req allocated=%p\n", (void*)req);
    fflush(stderr);

    req->used = 1;
    req->size = sizes[0];
    req->cq = comm->cq; // Store CQ for polling
    req->done = 0;

    // Setup scatter/gather entry
    sge.addr = (uintptr_t)data[0];
    sge.length = sizes[0];
    sge.lkey = lkey;

    fprintf(stderr, "MESH DEBUG: irecv sge setup done, about to post_recv\n");
    fprintf(stderr, "MESH DEBUG: irecv qp=%p\n", (void*)comm->qp);
    fprintf(stderr, "MESH DEBUG: irecv comm->nic=%p, comm->nic->pd=%p\n", (void*)comm->nic, (void*)(comm->nic ? comm->nic->pd : NULL));
    fprintf(stderr, "MESH DEBUG: irecv mrh->nic=%p, mrh->nic->pd=%p\n", (void*)mrh->nic, (void*)(mrh->nic ? mrh->nic->pd : NULL));

    // Check if PDs match!
+ if (comm->nic && mrh->nic && comm->nic->pd != mrh->nic->pd) { + fprintf(stderr, "MESH DEBUG: ERROR - PD MISMATCH! QP PD != MR PD\n"); + } + fflush(stderr); + + // Skip QP query - just try the post directly + fprintf(stderr, "MESH DEBUG: irecv skipping QP query, going straight to post_recv\n"); + fprintf(stderr, "MESH DEBUG: irecv about to call ibv_post_recv, qp=%p\n", (void*)comm->qp); + fflush(stderr); + + // Setup receive work request + memset(&wr, 0, sizeof(wr)); + wr.wr_id = (uintptr_t)req; + wr.next = NULL; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (ibv_post_recv(comm->qp, &wr, &bad_wr)) { + MESH_WARN("Failed to post recv: %s", strerror(errno)); + fprintf(stderr, "MESH DEBUG: ibv_post_recv FAILED: %s\n", strerror(errno)); + free(req); + return ncclSystemError; + } + + fprintf(stderr, "MESH DEBUG: ibv_post_recv succeeded!\n"); + fflush(stderr); + + *request = req; + return ncclSuccess; +} + +static ncclResult_t mesh_iflush(void *recvComm, int n, void **data, int *sizes, + void **mhandles, void **request) { + // No flush needed for verbs + *request = NULL; + return ncclSuccess; +} + +static ncclResult_t mesh_test(void *request, int *done, int *sizes) { + struct mesh_request *req = (struct mesh_request *)request; + struct ibv_wc wc; + int ret; + + if (!req) { + *done = 1; + return ncclSuccess; + } + + if (req->done) { + *done = 1; + if (sizes) *sizes = req->size; + return ncclSuccess; + } + + // Actually poll the completion queue + if (!req->cq) { + MESH_WARN("mesh_test: request has no CQ"); + req->done = 1; + *done = 1; + return ncclSuccess; + } + + ret = ibv_poll_cq(req->cq, 1, &wc); + if (ret < 0) { + MESH_WARN("mesh_test: ibv_poll_cq failed: %s", strerror(errno)); + return ncclSystemError; + } + + if (ret == 0) { + // No completion yet + *done = 0; + return ncclSuccess; + } + + // Got a completion + if (wc.status != IBV_WC_SUCCESS) { + MESH_WARN("mesh_test: WC error: status=%d (%s)", + wc.status, ibv_wc_status_str(wc.status)); + return ncclSystemError; 
+ } + + // Mark request as done + req->done = 1; + req->wc = wc; + *done = 1; + if (sizes) *sizes = req->size; + + return ncclSuccess; +} + +static ncclResult_t mesh_closeSend(void *sendComm) { + struct mesh_send_comm *comm = (struct mesh_send_comm *)sendComm; + + if (comm) { + if (comm->qp) ibv_destroy_qp(comm->qp); + if (comm->cq) ibv_destroy_cq(comm->cq); + free(comm); + } + + return ncclSuccess; +} + +static ncclResult_t mesh_closeRecv(void *recvComm) { + struct mesh_recv_comm *comm = (struct mesh_recv_comm *)recvComm; + + if (comm) { + // QP/CQ are now owned by recv_comm, destroy them + if (comm->qp) ibv_destroy_qp(comm->qp); + if (comm->cq) ibv_destroy_cq(comm->cq); + free(comm); + } + + return ncclSuccess; +} + +static ncclResult_t mesh_closeListen(void *listenComm) { + struct mesh_listen_comm *comm = (struct mesh_listen_comm *)listenComm; + + if (comm) { + // Stop handshake thread + if (comm->thread_running) { + comm->thread_stop = 1; + pthread_cond_broadcast(&comm->queue_cond); + pthread_join(comm->handshake_thread, NULL); + comm->thread_running = 0; + } + + // Close handshake socket + if (comm->handshake_sock >= 0) { + close(comm->handshake_sock); + } + + // Destroy mutex and condition + pthread_mutex_destroy(&comm->queue_mutex); + pthread_cond_destroy(&comm->queue_cond); + + // Clean up any remaining queue entries + for (int i = 0; i < HANDSHAKE_QUEUE_SIZE; i++) { + if (comm->handshake_queue[i].valid) { + if (comm->handshake_queue[i].local_qp) + ibv_destroy_qp(comm->handshake_queue[i].local_qp); + if (comm->handshake_queue[i].local_cq) + ibv_destroy_cq(comm->handshake_queue[i].local_cq); + } + } + + for (int i = 0; i < comm->num_qps; i++) { + if (comm->qps[i].qp) ibv_destroy_qp(comm->qps[i].qp); + if (comm->qps[i].cq) ibv_destroy_cq(comm->qps[i].cq); + } + free(comm); + } + + return ncclSuccess; +} + +static ncclResult_t mesh_getDeviceMr(void *comm, void *mhandle, void **dptr_mhandle) { + *dptr_mhandle = NULL; + return ncclSuccess; +} + +static 
ncclResult_t mesh_irecvConsumed(void *recvComm, int n, void *request) {
    // Nothing to do: requests carry no device-side resources to release here.
    return ncclSuccess;
}

/*
 * ============================================================================
 * Plugin Export
 * ============================================================================
 */

// NCCL looks this symbol up by name when loading libnccl-net.so.
__attribute__((visibility("default")))
const ncclNet_v8_t ncclNetPlugin_v8 = {
    .name = PLUGIN_NAME,
    .init = mesh_init,
    .devices = mesh_devices,
    .getProperties = mesh_getProperties,
    .listen = mesh_listen,
    .connect = mesh_connect,
    .accept = mesh_accept,
    .regMr = mesh_regMr,
    .regMrDmaBuf = mesh_regMrDmaBuf,
    .deregMr = mesh_deregMr,
    .isend = mesh_isend,
    .irecv = mesh_irecv,
    .iflush = mesh_iflush,
    .test = mesh_test,
    .closeSend = mesh_closeSend,
    .closeRecv = mesh_closeRecv,
    .closeListen = mesh_closeListen,
    .getDeviceMr = mesh_getDeviceMr,
    .irecvConsumed = mesh_irecvConsumed,
};

// Alias for NCCL to find
__attribute__((visibility("default")))
const ncclNet_v8_t *ncclNet_v8 = &ncclNetPlugin_v8;