Files
nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/symm_buffer.py
biondizzle c2b752c2fe Initial: TileLang NVFP4 mega_moe kernel package
- nvfp4_mega_moe_full: drop-in replacement for deep_gemm.mega.fp8_nvfp4_mega_moe
- transform_nvfp4_weights_for_mega_moe: weight transformation (tested)
- SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe: API-matching stubs
- MEGA_MOE_STATIC=1 support for pipeline testing
- pyproject.toml for pip install
2026-05-13 15:44:51 +00:00

87 lines
2.8 KiB
Python

"""Symmetric buffer for NVLink cross-rank all-reduce in mega_moe.

Drop-in replacement for ``deep_gemm.mega.SymmBuffer`` and
``deep_gemm.mega.get_symm_buffer_for_nvfp4_mega_moe``; the API matches the
DeepGEMM signatures used by the vLLM ``deepseek_v4.py`` patch.

NOTE(review): ``SymmBuffer`` here allocates plain, uninitialized
``torch.empty`` tensors on the current CUDA device. No symmetric-memory /
NVLink registration or collective call is made, and the process group is
only stored, so this is an API-matching stand-in rather than a real
cross-rank buffer.
"""
import os
import torch
import torch.distributed as dist
# Debug flag, read once at import time. Any non-zero integer turns on the
# shape logging in SymmBuffer.__init__; unset means 0 (off). A value that is
# not an integer literal makes int() raise ValueError during import.
MEGA_MOE_DEBUG = int(os.getenv("MEGA_MOE_DEBUG", "0"))
class SymmBuffer:
    """Staging and all-reduce buffers for the NVFP4 mega_moe kernel.

    Matches the DeepGEMM ``SymmBuffer`` interface expected by the vLLM patch:

    - ``.x``: staged activations, FP4 (E2M1) packed two values per byte;
      shape ``(max_num_tokens, ceil(hidden_size / 2))``, dtype ``int8``.
    - ``.x_sf``: staged activation scales, UE4M3 block-16 scales packed four
      per ``uint32``; shape ``(max_num_tokens, ceil(hidden_size / 64))``.
    - ``.topk_idx``: top-k expert indices, ``(max_num_tokens, top_k)`` int32.
    - ``.topk_weights``: top-k routing weights, ``(max_num_tokens, top_k)``
      float32.
    - ``.buffer``: all-reduce buffer, ``(max_num_tokens, hidden_size)``
      bfloat16.
    - ``.group``: the process group passed to the constructor.

    NOTE: this is an API-matching stand-in. Every tensor is an uninitialized
    ``torch.empty`` allocation on the current CUDA device. No symmetric-memory
    or NVLink registration is performed, and ``group`` is only stored.
    """

    def __init__(self, group, num_experts: int, max_num_tokens: int,
                 top_k: int, hidden_size: int, intermediate_size: int):
        """Allocate all buffers on ``torch.cuda.current_device()``.

        Args:
            group: process group; stored as ``self.group``, not used here.
            num_experts: total number of experts (stored only).
            max_num_tokens: leading dimension of every buffer.
            top_k: number of experts selected per token.
            hidden_size: hidden dimension of the staged activations.
            intermediate_size: expert intermediate dimension (stored only).
        """
        self.group = group
        self.num_experts = num_experts
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        device = torch.cuda.current_device()

        # NVFP4 packs two E2M1 values per byte. Round up so an odd
        # hidden_size never leaves the packed buffer one byte short.
        packed_hidden = self._ceil_div(hidden_size, 2)
        # One UE4M3 scale per 16 elements, four scales per uint32, i.e. one
        # uint32 per 64 elements. Ceil-div so a hidden_size that is not a
        # multiple of 64 still gets a slot for its tail scales instead of
        # being silently truncated (floor division undersized the buffer).
        sf_k_groups_hidden = self._ceil_div(hidden_size, 16 * 4)

        # Staging buffers (matching the DeepGEMM layout).
        self.x = torch.empty(
            max_num_tokens, packed_hidden,
            dtype=torch.int8, device=device,
        )
        self.x_sf = torch.empty(
            max_num_tokens, sf_k_groups_hidden,
            dtype=torch.uint32, device=device,
        )
        self.topk_idx = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.int32, device=device,
        )
        self.topk_weights = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.float32, device=device,
        )
        # All-reduce buffer (unpacked, one bf16 per hidden element).
        self.buffer = torch.empty(
            max_num_tokens, hidden_size,
            dtype=torch.bfloat16, device=device,
        )
        if MEGA_MOE_DEBUG:
            print(f"[SymmBuffer] x={self.x.shape} x_sf={self.x_sf.shape} "
                  f"topk_idx={self.topk_idx.shape} topk_weights={self.topk_weights.shape} "
                  f"buffer={self.buffer.shape}")

    @staticmethod
    def _ceil_div(a: int, b: int) -> int:
        """Return ``ceil(a / b)`` for a non-negative ``a`` and positive ``b``."""
        return (a + b - 1) // b
def get_symm_buffer_for_nvfp4_mega_moe(
    group,
    num_experts: int,
    max_num_tokens: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
) -> SymmBuffer:
    """Build a ``SymmBuffer`` for the NVFP4 mega_moe kernel.

    Mirrors ``deep_gemm.mega.get_symm_buffer_for_nvfp4_mega_moe`` so callers
    can switch implementations without code changes. Every argument is
    forwarded unchanged to the ``SymmBuffer`` constructor, and the new
    instance is returned.
    """
    symm = SymmBuffer(
        group=group,
        num_experts=num_experts,
        max_num_tokens=max_num_tokens,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
    )
    return symm