Mirror of https://github.com/autoscriptlabs/nccl-mesh-plugin.git, synced 2026-01-11 11:34:06 +00:00
- Enables NCCL over multi-subnet mesh topologies
- 8+ GB/s bandwidth over 100 Gbps RDMA
- Successfully tested with distributed LLM inference (Mistral-7B)
- Custom subnet-aware NIC selection
- Background handshake thread for deadlock-free connection setup
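A minimal sketch of loading an NCCL net plugin before running the benchmark below, assuming the plugin registers under the name "mesh" (so NCCL would dlopen libnccl-net-mesh.so); the plugin name and library location are assumptions, not confirmed by this repo:

    import os

    # Must run before the first NCCL call (e.g. dist.init_process_group),
    # since NCCL reads its environment at initialization.
    os.environ.setdefault('NCCL_NET_PLUGIN', 'mesh')  # assumed plugin name
    os.environ.setdefault('NCCL_DEBUG', 'INFO')       # log chosen transport/NIC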
87 lines
2.6 KiB
Python
#!/usr/bin/env python3
"""
Bandwidth benchmark for NCCL Mesh Plugin

Usage:
    # On each node (adjust --rank):
    python benchmark_bandwidth.py --rank 0 --world-size 3 --master-ip 10.0.0.170
"""

import argparse
import time

import torch
import torch.distributed as dist
def benchmark_allreduce(size_mb: int, iterations: int, warmup: int = 5):
    """Benchmark all-reduce bandwidth."""
    # Create tensor
    num_elements = (size_mb * 1024 * 1024) // 4  # float32 = 4 bytes
    tensor = torch.ones(num_elements, device='cuda', dtype=torch.float32)

    # Warmup
    for _ in range(warmup):
        dist.all_reduce(tensor)
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(iterations):
        dist.all_reduce(tensor)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    # Calculate bandwidth: a ring all-reduce moves 2*(N-1)/N of the buffer
    # over the wire per rank, so scale the raw payload by that factor.
    world_size = dist.get_world_size()
    ring_factor = 2 * (world_size - 1) / world_size
    total_data_gb = (size_mb * iterations * ring_factor) / 1024
    bandwidth_gbs = total_data_gb / elapsed

    return bandwidth_gbs, elapsed
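
# Worked example of the formula above (illustrative assumptions, not
# measured results): with world size 3 the ring factor is 2*(3-1)/3 = 4/3,
# so 20 iterations of a 512 MB all-reduce move 512/1024 * 20 * 4/3 ≈ 13.3 GB;
# if that loop takes 1.7 s, the reported bandwidth is ≈ 7.8 GB/s.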


def main():
    parser = argparse.ArgumentParser(description='Benchmark NCCL bandwidth')
    parser.add_argument('--rank', type=int, required=True)
    parser.add_argument('--world-size', type=int, default=3)
    parser.add_argument('--master-ip', type=str, default='10.0.0.170')
    parser.add_argument('--master-port', type=int, default=29500)
    parser.add_argument('--iterations', type=int, default=20)
    args = parser.parse_args()

    # Initialize
    init_method = f'tcp://{args.master_ip}:{args.master_port}'
    dist.init_process_group('nccl', rank=args.rank, world_size=args.world_size,
                            init_method=init_method)
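
    # NOTE: this script assumes one visible GPU per rank (device 'cuda').
    # On multi-GPU nodes you would call torch.cuda.set_device(...) with the
    # local rank before the first collective; that is not handled here.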

    if args.rank == 0:
        print(f'\n{"="*60}')
        print('NCCL Mesh Plugin Bandwidth Benchmark')
        print(f'World size: {args.world_size}')
        print(f'Iterations per size: {args.iterations}')
        print(f'{"="*60}\n')
        print(f'{"Size":<12} {"Bandwidth":<15} {"Time":<12}')
        print(f'{"-"*12} {"-"*15} {"-"*12}')

    # Test different sizes
    sizes_mb = [1, 4, 16, 64, 128, 256, 512]

    for size_mb in sizes_mb:
        bandwidth, elapsed = benchmark_allreduce(size_mb, args.iterations)

        if args.rank == 0:
            print(f'{size_mb:>6} MB {bandwidth:>8.2f} GB/s {elapsed:>6.3f} s')

        # Sync between sizes
        dist.barrier()

    if args.rank == 0:
        print(f'\n{"="*60}')
        print('Benchmark complete!')
        print(f'{"="*60}\n')

    dist.destroy_process_group()


if __name__ == '__main__':
    main()