nccl-mesh-plugin/examples/distributed_llm.py
autoscriptlabs 031bc48953 Initial release: NCCL Mesh Plugin for direct-connect RDMA topologies
- Enables NCCL over multi-subnet mesh topologies
- 8+ GB/s bandwidth over 100Gbps RDMA
- Successfully tested with distributed LLM inference (Mistral-7B)
- Custom subnet-aware NIC selection
- Background handshake thread for deadlock-free connection setup
2026-01-09 14:09:33 -05:00

#!/usr/bin/env python3
"""
Distributed LLM Inference with NCCL Mesh Plugin

This example demonstrates loading and running inference on a large language
model distributed across multiple GPUs using the NCCL Mesh Plugin.

Usage:
    # On each node (adjust --rank):
    python distributed_llm.py --rank 0 --world-size 3 --master-ip 10.0.0.170

Environment setup (run on each node; a sanity check after the imports warns
if the plugin variable is not set):
    cd ~/nccl-mesh-plugin
    export LD_LIBRARY_PATH=$(pwd):$LD_LIBRARY_PATH
    export NCCL_NET_PLUGIN=mesh
    export NCCL_DEBUG=WARN
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
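import os

# Sanity check (not in the original example): the docstring above asks for
# NCCL_NET_PLUGIN=mesh to be exported on every node. NCCL_NET_PLUGIN is a
# standard NCCL environment variable; the value 'mesh' is taken from the
# setup instructions above. Warn early instead of failing deep inside NCCL.
if os.environ.get('NCCL_NET_PLUGIN') != 'mesh':
    print('Warning: NCCL_NET_PLUGIN is not set to "mesh"; '
          'NCCL may fall back to its default transport.')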


def main():
    parser = argparse.ArgumentParser(description='Distributed LLM Inference')
    parser.add_argument('--rank', type=int, required=True)
    parser.add_argument('--world-size', type=int, default=3)
    parser.add_argument('--master-ip', type=str, default='10.0.0.170')
    parser.add_argument('--master-port', type=int, default=29500)
    parser.add_argument('--model', type=str, default='mistralai/Mistral-7B-Instruct-v0.2',
                        help='Model to load (default: Mistral-7B)')
    parser.add_argument('--prompt', type=str,
                        default='The future of distributed AI computing is',
                        help='Prompt for generation')
    parser.add_argument('--max-tokens', type=int, default=100,
                        help='Maximum tokens to generate')
    args = parser.parse_args()
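
    # Not part of the original listing: a sketch that wires the parsed CLI
    # arguments into the environment variables torch.distributed expects
    # (the same ones `torchrun` / `accelerate launch` would normally export),
    # so that Accelerator() below can initialize the process group.
    # Assumes one GPU per node, hence LOCAL_RANK=0.
    os.environ.setdefault('RANK', str(args.rank))
    os.environ.setdefault('LOCAL_RANK', '0')
    os.environ.setdefault('WORLD_SIZE', str(args.world_size))
    os.environ.setdefault('MASTER_ADDR', args.master_ip)
    os.environ.setdefault('MASTER_PORT', str(args.master_port))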

    # Initialize accelerator
    accelerator = Accelerator()

    print(f'Rank {accelerator.process_index}: Loading tokenizer...')
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    print(f'Rank {accelerator.process_index}: Loading model...')
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
    print(f'Rank {accelerator.process_index}: Model loaded!')

    # Only rank 0 generates
    if accelerator.is_main_process:
        print('\nGenerating text...')
        print(f'Prompt: "{args.prompt}"\n')

        inputs = tokenizer(args.prompt, return_tensors='pt').to('cuda')
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)

        print('=' * 60)
        print('Generated Text:')
        print('=' * 60)
        print(result)
        print('=' * 60)

    # Wait for all ranks
    accelerator.wait_for_everyone()
    print(f'Rank {accelerator.process_index}: Done!')


if __name__ == '__main__':
    main()