# Commit notes for this package:
# - nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe
# - transform_nvfp4_weights_for_mega_moe: weight transformation (tested)
# - SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs
# - MEGA_MOE_STATIC=1 support for pipeline testing
# - pyproject.toml for pip install
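"""Symmetric-buffer stand-ins for the NVFP4 mega_moe path.

Replaces deep_gemm.mega.SymmBuffer and get_symm_buffer_for_nvfp4_mega_moe.
The API matches the DeepGEMM signatures used in the vLLM deepseek_v4.py patch.

As written, these are API-matching stubs: the buffers are plain single-device
tensor allocations. No NVLink / symmetric-memory registration and no
cross-rank all-reduce is performed in this module.
"""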
|
|
|
|
import os
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
# Debug switch, read once at import time from the MEGA_MOE_DEBUG environment
# variable (default "0"). Any non-zero integer enables the SymmBuffer shape
# dump; a value that is not an integer literal raises ValueError on import.
MEGA_MOE_DEBUG = int(os.getenv("MEGA_MOE_DEBUG", "0"))
|
|
|
|
|
|
class SymmBuffer:
    """Single-device stand-in for DeepGEMM's ``SymmBuffer`` (NVFP4 mega_moe).

    Exposes the attributes the vLLM patch reads from
    ``deep_gemm.mega.SymmBuffer``:

    - ``.x``: staged activation, FP4 (E2M1) packed two values per byte,
      ``int8`` of shape ``(max_num_tokens, ceil(hidden_size / 2))``
    - ``.x_sf``: staged activation scales (UE4M3, one per 16-element block,
      four packed per ``uint32``), shape
      ``(max_num_tokens, ceil(hidden_size / 64))``
    - ``.topk_idx``: top-k expert indices, ``int32`` ``(max_num_tokens, top_k)``
    - ``.topk_weights``: top-k routing weights,
      ``float32`` ``(max_num_tokens, top_k)``
    - ``.buffer``: all-reduce buffer,
      ``bfloat16`` ``(max_num_tokens, hidden_size)``
    - ``.group``: the process group passed in

    NOTE: every tensor is an uninitialised ``torch.empty`` allocation on a
    single device. Nothing is registered as symmetric / NVLink-peer memory,
    and ``group`` is only stored. This is an API-matching stub, not a
    functional cross-rank buffer.
    """

    def __init__(self, group, num_experts, max_num_tokens, top_k,
                 hidden_size, intermediate_size, *, device=None):
        """Allocate the staging and all-reduce tensors.

        Args:
            group: process group; stored as ``self.group``, not used here.
            num_experts: number of experts; stored only.
            max_num_tokens: leading dimension of every allocated tensor.
            top_k: number of routed experts per token.
            hidden_size: hidden dimension of the staged activation.
            intermediate_size: expert intermediate size; stored only.
            device: where to allocate. ``None`` (the default) keeps the
                original behaviour of using ``torch.cuda.current_device()``.
        """
        self.group = group
        self.num_experts = num_experts
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        if device is None:
            device = torch.cuda.current_device()

        # NVFP4: packed E2M1 (2 values per byte). Ceil-divide so an odd
        # hidden_size cannot lose its last element.
        packed_hidden = -(-hidden_size // 2)
        # UE4M3 scale per 16-element block, 4 scales packed per uint32.
        # Ceil-divide so a hidden_size that is not a multiple of 64 still gets
        # a slot for its tail block instead of silently under-allocating.
        # (ceil(ceil(h / 16) / 4) == ceil(h / 64).) Identical for aligned sizes.
        sf_k_groups_hidden = -(-hidden_size // (16 * 4))

        # Staging buffers (matching DeepGEMM layout)
        self.x = torch.empty(
            max_num_tokens, packed_hidden,
            dtype=torch.int8, device=device,
        )
        self.x_sf = torch.empty(
            max_num_tokens, sf_k_groups_hidden,
            dtype=torch.uint32, device=device,
        )
        self.topk_idx = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.int32, device=device,
        )
        self.topk_weights = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.float32, device=device,
        )

        # All-reduce buffer (full-precision hidden states, not packed)
        self.buffer = torch.empty(
            max_num_tokens, hidden_size,
            dtype=torch.bfloat16, device=device,
        )

        if MEGA_MOE_DEBUG:
            print(f"[SymmBuffer] x={self.x.shape} x_sf={self.x_sf.shape} "
                  f"topk_idx={self.topk_idx.shape} topk_weights={self.topk_weights.shape} "
                  f"buffer={self.buffer.shape}")
|
|
|
|
|
|
def get_symm_buffer_for_nvfp4_mega_moe(
    group,
    num_experts: int,
    max_num_tokens: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
) -> SymmBuffer:
    """Build the :class:`SymmBuffer` consumed by the NVFP4 mega_moe kernel.

    Mirrors the signature of
    ``deep_gemm.mega.get_symm_buffer_for_nvfp4_mega_moe``. It forwards every
    argument unchanged to the ``SymmBuffer`` constructor and returns the
    new instance.
    """
    buf = SymmBuffer(
        group=group,
        num_experts=num_experts,
        max_num_tokens=max_num_tokens,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
    )
    return buf
|