mirror of
https://github.com/autoscriptlabs/nccl-mesh-plugin.git
synced 2026-01-11 11:34:06 +00:00
Initial release: NCCL Mesh Plugin for direct-connect RDMA topologies
- Enables NCCL over multi-subnet mesh topologies - 8+ GB/s bandwidth over 100Gbps RDMA - Successfully tested with distributed LLM inference (Mistral-7B) - Custom subnet-aware NIC selection - Background handshake thread for deadlock-free connection setup
This commit is contained in:
commit
031bc48953
13 changed files with 3074 additions and 0 deletions
257
include/mesh_plugin.h
Normal file
257
include/mesh_plugin.h
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
/*
|
||||
* NCCL Mesh Plugin - Subnet-aware RDMA transport
|
||||
*
|
||||
* Enables NCCL to work with direct-connect mesh topologies where
|
||||
* each node pair is on a different subnet.
|
||||
*/
|
||||
|
||||
#ifndef NCCL_MESH_PLUGIN_H
|
||||
#define NCCL_MESH_PLUGIN_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <pthread.h>
|
||||
#include <infiniband/verbs.h>
|
||||
|
||||
#define MESH_MAX_NICS 8
|
||||
#define MESH_MAX_QPS 256
|
||||
#define MESH_MAX_MRS 1024
|
||||
#define MESH_HANDLE_MAGIC 0x4D455348 // "MESH"
|
||||
|
||||
// Forward declarations
|
||||
struct mesh_plugin_state;
|
||||
struct mesh_nic;
|
||||
struct mesh_comm;
|
||||
|
||||
/*
|
||||
* Represents one RDMA-capable NIC with its subnet information
|
||||
*/
|
||||
struct mesh_nic {
|
||||
// RDMA resources
|
||||
struct ibv_context *context;
|
||||
struct ibv_pd *pd;
|
||||
int port_num;
|
||||
int gid_index;
|
||||
|
||||
// Network addressing
|
||||
uint32_t ip_addr; // Host byte order
|
||||
uint32_t netmask; // Host byte order
|
||||
uint32_t subnet; // ip_addr & netmask
|
||||
|
||||
// Device identification
|
||||
char dev_name[64]; // RDMA device name (e.g., "rocep1s0f1")
|
||||
char if_name[64]; // Network interface name (e.g., "enp1s0f1np1")
|
||||
char pci_path[256]; // PCI bus path
|
||||
|
||||
// Capabilities
|
||||
int max_qp;
|
||||
int max_cq;
|
||||
int max_mr;
|
||||
int max_sge;
|
||||
uint64_t max_mr_size;
|
||||
int gdr_supported; // GPUDirect RDMA support
|
||||
|
||||
// Statistics
|
||||
uint64_t bytes_sent;
|
||||
uint64_t bytes_recv;
|
||||
uint64_t connections;
|
||||
};
|
||||
|
||||
/*
|
||||
* Address entry for multi-homed hosts
|
||||
*/
|
||||
#define MESH_MAX_ADDRS 6
|
||||
|
||||
struct mesh_addr_entry {
|
||||
uint32_t ip; // IP address (network byte order)
|
||||
uint32_t mask; // Subnet mask (network byte order)
|
||||
uint16_t qp_num; // QP number for this NIC
|
||||
uint8_t nic_idx; // Index into our NIC array
|
||||
uint8_t gid_index; // GID index for this NIC
|
||||
};
|
||||
|
||||
/*
|
||||
* Connection handle - exchanged between peers during setup
|
||||
* Must fit within NCCL_NET_HANDLE_MAXSIZE (128 bytes)
|
||||
*/
|
||||
struct mesh_handle {
|
||||
uint32_t magic; // MESH_HANDLE_MAGIC
|
||||
uint8_t num_addrs; // Number of valid addresses
|
||||
uint8_t selected_idx; // Which address was selected (set by connect)
|
||||
uint16_t lid; // IB LID (0 for RoCE)
|
||||
uint16_t qp_num; // QP number (for compat with mesh_connect_qp)
|
||||
uint16_t handshake_port; // TCP port for QP handshake
|
||||
uint8_t port_num; // Port number (usually 1)
|
||||
uint8_t mtu; // MTU setting
|
||||
uint32_t psn; // Packet sequence number
|
||||
uint32_t handshake_ip; // IP address for handshake (network byte order)
|
||||
union ibv_gid gid; // GID (16 bytes)
|
||||
struct mesh_addr_entry addrs[MESH_MAX_ADDRS]; // 12 bytes each
|
||||
// Total: 4+1+1+2+2+2+1+1+4+4+16 + 6*12 = 38 + 72 = 110 bytes (fits in 128)
|
||||
};
|
||||
|
||||
/*
|
||||
* Listen state - waiting for incoming connections
|
||||
* Creates QPs on ALL NICs so any peer can connect
|
||||
*/
|
||||
#define HANDSHAKE_QUEUE_SIZE 16
|
||||
|
||||
/*
|
||||
* QP info exchanged during handshake
|
||||
*/
|
||||
struct mesh_qp_info {
|
||||
uint32_t qp_num; // Network byte order
|
||||
uint32_t psn; // Network byte order
|
||||
uint8_t gid[16]; // Raw GID
|
||||
uint32_t ip; // Network byte order
|
||||
uint8_t gid_index;
|
||||
uint8_t nic_idx; // Which NIC on the listener
|
||||
uint8_t reserved[2];
|
||||
};
|
||||
|
||||
struct handshake_entry {
|
||||
struct mesh_qp_info remote_info;
|
||||
struct ibv_qp *local_qp;
|
||||
struct ibv_cq *local_cq;
|
||||
struct mesh_nic *nic;
|
||||
int valid;
|
||||
};
|
||||
|
||||
struct mesh_listen_comm {
|
||||
int num_qps;
|
||||
struct {
|
||||
struct mesh_nic *nic;
|
||||
struct ibv_qp *qp;
|
||||
struct ibv_cq *cq;
|
||||
} qps[MESH_MAX_NICS];
|
||||
uint32_t psn;
|
||||
int ready;
|
||||
|
||||
// Handshake socket for QP info exchange
|
||||
int handshake_sock;
|
||||
uint16_t handshake_port;
|
||||
uint32_t handshake_ip;
|
||||
|
||||
// Background handshake thread
|
||||
pthread_t handshake_thread;
|
||||
int thread_running;
|
||||
int thread_stop;
|
||||
|
||||
// Queue of received handshakes for accept() to consume
|
||||
struct handshake_entry handshake_queue[HANDSHAKE_QUEUE_SIZE];
|
||||
int queue_head;
|
||||
int queue_tail;
|
||||
pthread_mutex_t queue_mutex;
|
||||
pthread_cond_t queue_cond;
|
||||
};
|
||||
|
||||
/*
|
||||
* Send/Receive communication state
|
||||
*/
|
||||
struct mesh_send_comm {
|
||||
struct mesh_nic *nic;
|
||||
struct ibv_qp *qp;
|
||||
struct ibv_cq *cq;
|
||||
uint32_t remote_qp_num;
|
||||
union ibv_gid remote_gid;
|
||||
int connected;
|
||||
|
||||
// Request tracking
|
||||
struct mesh_request *requests[MESH_MAX_QPS];
|
||||
int num_requests;
|
||||
};
|
||||
|
||||
struct mesh_recv_comm {
|
||||
struct mesh_nic *nic;
|
||||
struct ibv_qp *qp;
|
||||
struct ibv_cq *cq;
|
||||
int connected;
|
||||
|
||||
// Request tracking
|
||||
struct mesh_request *requests[MESH_MAX_QPS];
|
||||
int num_requests;
|
||||
};
|
||||
|
||||
/*
|
||||
* Memory registration handle
|
||||
*/
|
||||
struct mesh_mr_handle {
|
||||
struct ibv_mr *mr;
|
||||
struct mesh_nic *nic;
|
||||
void *addr;
|
||||
size_t size;
|
||||
};
|
||||
|
||||
/*
|
||||
* Async request state
|
||||
*/
|
||||
struct mesh_request {
|
||||
int used;
|
||||
int done;
|
||||
size_t size;
|
||||
struct ibv_cq *cq; // CQ to poll for completion
|
||||
struct ibv_wc wc;
|
||||
};
|
||||
|
||||
/*
|
||||
* Global plugin state
|
||||
*/
|
||||
struct mesh_plugin_state {
|
||||
struct mesh_nic nics[MESH_MAX_NICS];
|
||||
int num_nics;
|
||||
int initialized;
|
||||
|
||||
// Configuration
|
||||
int gid_index; // From NCCL_MESH_GID_INDEX
|
||||
int debug; // From NCCL_MESH_DEBUG
|
||||
|
||||
// Logging (provided by NCCL)
|
||||
void (*log_fn)(int level, unsigned long flags, const char *file,
|
||||
int line, const char *fmt, ...);
|
||||
};
|
||||
|
||||
// Global state (singleton)
|
||||
extern struct mesh_plugin_state g_mesh_state;
|
||||
|
||||
/*
|
||||
* Internal functions
|
||||
*/
|
||||
|
||||
// Initialization
|
||||
int mesh_init_nics(void);
|
||||
int mesh_discover_nic_ips(void);
|
||||
int mesh_setup_nic(struct mesh_nic *nic, struct ibv_device *device);
|
||||
|
||||
// Routing
|
||||
struct mesh_nic* mesh_find_nic_for_ip(uint32_t peer_ip);
|
||||
struct mesh_nic* mesh_find_nic_by_name(const char *name);
|
||||
int mesh_get_nic_index(struct mesh_nic *nic);
|
||||
|
||||
// RDMA operations
|
||||
int mesh_create_qp(struct mesh_nic *nic, struct ibv_qp **qp, struct ibv_cq **cq);
|
||||
int mesh_connect_qp(struct ibv_qp *qp, struct mesh_nic *nic, struct mesh_handle *remote);
|
||||
int mesh_post_send(struct mesh_send_comm *comm, void *data, size_t size,
|
||||
struct mesh_mr_handle *mr, struct mesh_request *req);
|
||||
int mesh_post_recv(struct mesh_recv_comm *comm, void *data, size_t size,
|
||||
struct mesh_mr_handle *mr, struct mesh_request *req);
|
||||
int mesh_poll_cq(struct ibv_cq *cq, struct mesh_request *req);
|
||||
|
||||
// Utilities
|
||||
uint32_t mesh_ip_to_uint(const char *ip_str);
|
||||
void mesh_uint_to_ip(uint32_t ip, char *buf, size_t len);
|
||||
int mesh_get_interface_ip(const char *if_name, uint32_t *ip, uint32_t *mask);
|
||||
const char* mesh_find_netdev_for_rdma(const char *rdma_dev);
|
||||
|
||||
// Logging macros
|
||||
#define MESH_LOG(level, fmt, ...) \
|
||||
do { \
|
||||
if (g_mesh_state.log_fn) { \
|
||||
g_mesh_state.log_fn(level, 0, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
#define MESH_INFO(fmt, ...) MESH_LOG(NCCL_LOG_INFO, "MESH " fmt, ##__VA_ARGS__)
|
||||
#define MESH_WARN(fmt, ...) MESH_LOG(NCCL_LOG_WARN, "MESH " fmt, ##__VA_ARGS__)
|
||||
#define MESH_DEBUG(fmt, ...) \
|
||||
do { if (g_mesh_state.debug) MESH_LOG(NCCL_LOG_TRACE, "MESH " fmt, ##__VA_ARGS__); } while(0)
|
||||
|
||||
#endif // NCCL_MESH_PLUGIN_H
|
||||
Loading…
Add table
Add a link
Reference in a new issue