mirror of
https://github.com/autoscriptlabs/nccl-mesh-plugin.git
synced 2026-01-11 11:34:06 +00:00
- Enables NCCL over multi-subnet mesh topologies - 8+ GB/s bandwidth over 100Gbps RDMA - Successfully tested with distributed LLM inference (Mistral-7B) - Custom subnet-aware NIC selection - Background handshake thread for deadlock-free connection setup
257 lines
7.1 KiB
C
257 lines
7.1 KiB
C
/*
|
|
* NCCL Mesh Plugin - Subnet-aware RDMA transport
|
|
*
|
|
* Enables NCCL to work with direct-connect mesh topologies where
|
|
* each node pair is on a different subnet.
|
|
*/
|
|
|
|
#ifndef NCCL_MESH_PLUGIN_H
|
|
#define NCCL_MESH_PLUGIN_H
|
|
|
|
#include <stdint.h>
|
|
#include <pthread.h>
|
|
#include <infiniband/verbs.h>
|
|
|
|
#define MESH_MAX_NICS 8
|
|
#define MESH_MAX_QPS 256
|
|
#define MESH_MAX_MRS 1024
|
|
#define MESH_HANDLE_MAGIC 0x4D455348 // "MESH"
|
|
|
|
// Forward declarations
|
|
struct mesh_plugin_state;
|
|
struct mesh_nic;
|
|
struct mesh_comm;
|
|
|
|
/*
|
|
* Represents one RDMA-capable NIC with its subnet information
|
|
*/
|
|
struct mesh_nic {
|
|
// RDMA resources
|
|
struct ibv_context *context;
|
|
struct ibv_pd *pd;
|
|
int port_num;
|
|
int gid_index;
|
|
|
|
// Network addressing
|
|
uint32_t ip_addr; // Host byte order
|
|
uint32_t netmask; // Host byte order
|
|
uint32_t subnet; // ip_addr & netmask
|
|
|
|
// Device identification
|
|
char dev_name[64]; // RDMA device name (e.g., "rocep1s0f1")
|
|
char if_name[64]; // Network interface name (e.g., "enp1s0f1np1")
|
|
char pci_path[256]; // PCI bus path
|
|
|
|
// Capabilities
|
|
int max_qp;
|
|
int max_cq;
|
|
int max_mr;
|
|
int max_sge;
|
|
uint64_t max_mr_size;
|
|
int gdr_supported; // GPUDirect RDMA support
|
|
|
|
// Statistics
|
|
uint64_t bytes_sent;
|
|
uint64_t bytes_recv;
|
|
uint64_t connections;
|
|
};
|
|
|
|
/*
|
|
* Address entry for multi-homed hosts
|
|
*/
|
|
#define MESH_MAX_ADDRS 6
|
|
|
|
struct mesh_addr_entry {
|
|
uint32_t ip; // IP address (network byte order)
|
|
uint32_t mask; // Subnet mask (network byte order)
|
|
uint16_t qp_num; // QP number for this NIC
|
|
uint8_t nic_idx; // Index into our NIC array
|
|
uint8_t gid_index; // GID index for this NIC
|
|
};
|
|
|
|
/*
|
|
* Connection handle - exchanged between peers during setup
|
|
* Must fit within NCCL_NET_HANDLE_MAXSIZE (128 bytes)
|
|
*/
|
|
struct mesh_handle {
|
|
uint32_t magic; // MESH_HANDLE_MAGIC
|
|
uint8_t num_addrs; // Number of valid addresses
|
|
uint8_t selected_idx; // Which address was selected (set by connect)
|
|
uint16_t lid; // IB LID (0 for RoCE)
|
|
uint16_t qp_num; // QP number (for compat with mesh_connect_qp)
|
|
uint16_t handshake_port; // TCP port for QP handshake
|
|
uint8_t port_num; // Port number (usually 1)
|
|
uint8_t mtu; // MTU setting
|
|
uint32_t psn; // Packet sequence number
|
|
uint32_t handshake_ip; // IP address for handshake (network byte order)
|
|
union ibv_gid gid; // GID (16 bytes)
|
|
struct mesh_addr_entry addrs[MESH_MAX_ADDRS]; // 12 bytes each
|
|
// Total: 4+1+1+2+2+2+1+1+4+4+16 + 6*12 = 38 + 72 = 110 bytes (fits in 128)
|
|
};
|
|
|
|
/*
|
|
* Listen state - waiting for incoming connections
|
|
* Creates QPs on ALL NICs so any peer can connect
|
|
*/
|
|
#define HANDSHAKE_QUEUE_SIZE 16
|
|
|
|
/*
|
|
* QP info exchanged during handshake
|
|
*/
|
|
struct mesh_qp_info {
|
|
uint32_t qp_num; // Network byte order
|
|
uint32_t psn; // Network byte order
|
|
uint8_t gid[16]; // Raw GID
|
|
uint32_t ip; // Network byte order
|
|
uint8_t gid_index;
|
|
uint8_t nic_idx; // Which NIC on the listener
|
|
uint8_t reserved[2];
|
|
};
|
|
|
|
struct handshake_entry {
|
|
struct mesh_qp_info remote_info;
|
|
struct ibv_qp *local_qp;
|
|
struct ibv_cq *local_cq;
|
|
struct mesh_nic *nic;
|
|
int valid;
|
|
};
|
|
|
|
struct mesh_listen_comm {
|
|
int num_qps;
|
|
struct {
|
|
struct mesh_nic *nic;
|
|
struct ibv_qp *qp;
|
|
struct ibv_cq *cq;
|
|
} qps[MESH_MAX_NICS];
|
|
uint32_t psn;
|
|
int ready;
|
|
|
|
// Handshake socket for QP info exchange
|
|
int handshake_sock;
|
|
uint16_t handshake_port;
|
|
uint32_t handshake_ip;
|
|
|
|
// Background handshake thread
|
|
pthread_t handshake_thread;
|
|
int thread_running;
|
|
int thread_stop;
|
|
|
|
// Queue of received handshakes for accept() to consume
|
|
struct handshake_entry handshake_queue[HANDSHAKE_QUEUE_SIZE];
|
|
int queue_head;
|
|
int queue_tail;
|
|
pthread_mutex_t queue_mutex;
|
|
pthread_cond_t queue_cond;
|
|
};
|
|
|
|
/*
|
|
* Send/Receive communication state
|
|
*/
|
|
struct mesh_send_comm {
|
|
struct mesh_nic *nic;
|
|
struct ibv_qp *qp;
|
|
struct ibv_cq *cq;
|
|
uint32_t remote_qp_num;
|
|
union ibv_gid remote_gid;
|
|
int connected;
|
|
|
|
// Request tracking
|
|
struct mesh_request *requests[MESH_MAX_QPS];
|
|
int num_requests;
|
|
};
|
|
|
|
struct mesh_recv_comm {
|
|
struct mesh_nic *nic;
|
|
struct ibv_qp *qp;
|
|
struct ibv_cq *cq;
|
|
int connected;
|
|
|
|
// Request tracking
|
|
struct mesh_request *requests[MESH_MAX_QPS];
|
|
int num_requests;
|
|
};
|
|
|
|
/*
|
|
* Memory registration handle
|
|
*/
|
|
struct mesh_mr_handle {
|
|
struct ibv_mr *mr;
|
|
struct mesh_nic *nic;
|
|
void *addr;
|
|
size_t size;
|
|
};
|
|
|
|
/*
|
|
* Async request state
|
|
*/
|
|
struct mesh_request {
|
|
int used;
|
|
int done;
|
|
size_t size;
|
|
struct ibv_cq *cq; // CQ to poll for completion
|
|
struct ibv_wc wc;
|
|
};
|
|
|
|
/*
|
|
* Global plugin state
|
|
*/
|
|
struct mesh_plugin_state {
|
|
struct mesh_nic nics[MESH_MAX_NICS];
|
|
int num_nics;
|
|
int initialized;
|
|
|
|
// Configuration
|
|
int gid_index; // From NCCL_MESH_GID_INDEX
|
|
int debug; // From NCCL_MESH_DEBUG
|
|
|
|
// Logging (provided by NCCL)
|
|
void (*log_fn)(int level, unsigned long flags, const char *file,
|
|
int line, const char *fmt, ...);
|
|
};
|
|
|
|
// Global state (singleton)
|
|
extern struct mesh_plugin_state g_mesh_state;
|
|
|
|
/*
|
|
* Internal functions
|
|
*/
|
|
|
|
// Initialization
|
|
int mesh_init_nics(void);
|
|
int mesh_discover_nic_ips(void);
|
|
int mesh_setup_nic(struct mesh_nic *nic, struct ibv_device *device);
|
|
|
|
// Routing
|
|
struct mesh_nic* mesh_find_nic_for_ip(uint32_t peer_ip);
|
|
struct mesh_nic* mesh_find_nic_by_name(const char *name);
|
|
int mesh_get_nic_index(struct mesh_nic *nic);
|
|
|
|
// RDMA operations
|
|
int mesh_create_qp(struct mesh_nic *nic, struct ibv_qp **qp, struct ibv_cq **cq);
|
|
int mesh_connect_qp(struct ibv_qp *qp, struct mesh_nic *nic, struct mesh_handle *remote);
|
|
int mesh_post_send(struct mesh_send_comm *comm, void *data, size_t size,
|
|
struct mesh_mr_handle *mr, struct mesh_request *req);
|
|
int mesh_post_recv(struct mesh_recv_comm *comm, void *data, size_t size,
|
|
struct mesh_mr_handle *mr, struct mesh_request *req);
|
|
int mesh_poll_cq(struct ibv_cq *cq, struct mesh_request *req);
|
|
|
|
// Utilities
|
|
uint32_t mesh_ip_to_uint(const char *ip_str);
|
|
void mesh_uint_to_ip(uint32_t ip, char *buf, size_t len);
|
|
int mesh_get_interface_ip(const char *if_name, uint32_t *ip, uint32_t *mask);
|
|
const char* mesh_find_netdev_for_rdma(const char *rdma_dev);
|
|
|
|
// Logging macros
|
|
#define MESH_LOG(level, fmt, ...) \
|
|
do { \
|
|
if (g_mesh_state.log_fn) { \
|
|
g_mesh_state.log_fn(level, 0, __FILE__, __LINE__, fmt, ##__VA_ARGS__); \
|
|
} \
|
|
} while(0)
|
|
|
|
#define MESH_INFO(fmt, ...) MESH_LOG(NCCL_LOG_INFO, "MESH " fmt, ##__VA_ARGS__)
|
|
#define MESH_WARN(fmt, ...) MESH_LOG(NCCL_LOG_WARN, "MESH " fmt, ##__VA_ARGS__)
|
|
#define MESH_DEBUG(fmt, ...) \
|
|
do { if (g_mesh_state.debug) MESH_LOG(NCCL_LOG_TRACE, "MESH " fmt, ##__VA_ARGS__); } while(0)
|
|
|
|
#endif // NCCL_MESH_PLUGIN_H
|