nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/symm_buffer.py
biondizzle 9908fd64d9 feat: CUTLASS NVFP4 mega_moe kernel — slot-based L1/L2, source-first SF remap
Major changes from initial TileLang prototype:

Kernel:
- CUTLASS NVFP4 block-scaled GEMM (SM100 Blackwell, OpClassBlockScaledTensorOp)
- Slot-based dispatch: L1 GEMM → SiLU+Mul per-slot → L2 GEMM → index_add scatter
- 1D slot_expert_ids passed to both L1 and L2 (no 2D topk_ids rebuild)
- slot_token gathered in cutlass_grouped_nvfp4_gemm when provided

SF Remap (source-first):
- Iterates logical (m, k_sf) source grid, uses layout_sf(make_coord(m, k_sf))
  for CUTLASS dest index — no idx2crd/flatten coordinate extraction
- 2D kernel launch: dim3 block(32,8), grid over (K_sf, MN)
- Uses cute::cosize() for physical allocation size (not cute::size)
- SFA: (MN, K_sf) row-major; SFB: (K_sf, MN) row-major, i.e. (MN, K_sf) column-major

Weight transform:
- UE4M3 unpack with bit reinterpret (not value cast)
- Global scale folding (weight_scale_2) for gate/up split
- clamp(0,448) → float8_e4m3fn, transpose (N,K)→(K,N) for CUTLASS
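
The difference between a bit reinterpret and a value cast can be illustrated with a minimal pure-Python sketch. It is not the repository's `unpack_ue4m3_u32`; the helper names and the lowest-byte-first order inside the uint32 are assumptions made for illustration. It decodes each byte as an E4M3 (fn variant) bit pattern, which is why the largest finite scale is 448:

```python
def decode_e4m3fn(byte):
    """Interpret one byte as an FP8 E4M3 (fn variant) bit pattern."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a single byte")
    sign = -1.0 if byte & 0x80 else 1.0   # UE4M3 scales keep this bit at 0
    exponent = (byte >> 3) & 0xF          # 4 exponent bits, bias 7
    mantissa = byte & 0x7                 # 3 mantissa bits
    if exponent == 0xF and mantissa == 0x7:
        return float("nan")               # only NaN encoding, no infinities
    if exponent == 0:
        return sign * (mantissa / 8.0) * 2.0 ** -6   # subnormal
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def unpack_u32(word):
    """Split a uint32 into four bytes, lowest byte first (assumed order)."""
    return [(word >> (8 * i)) & 0xFF for i in range(4)]


word = 0x7E403800                      # four packed scale bytes
codes = unpack_u32(word)               # [0x00, 0x38, 0x40, 0x7E]
reinterpreted = [decode_e4m3fn(c) for c in codes]   # bit reinterpret
value_cast = [float(c) for c in codes]              # what a value cast yields

print(reinterpreted)  # [0.0, 1.0, 2.0, 448.0]
print(value_cast)     # [0.0, 56.0, 64.0, 126.0]
```

A value cast treats the byte as the integer 56 rather than as the encoding of 1.0, which silently corrupts every scale.
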

No prepack cache:
- SFB remapped per-call inside CUTLASS (~µs, not the bottleneck)
- See README for why prepack cache must never return (OOM, CUDA graphs,
  M-dependent layout, cross-layer collisions)

Stage activation:
- Nearest-neighbor E2M1 quantization (no clamp, no uniform steps)
- Per-tensor global scale → alpha for L2 GEMM
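
Nearest-neighbor E2M1 quantization can be sketched in plain Python as follows. This is an illustrative model rather than the repository's `stage_activation`: the function names, the tie-breaking rule (ties go to the smaller magnitude) and the low-nibble-first packing order are assumptions. The eight E2M1 magnitudes are not evenly spaced, so the sketch searches the table instead of dividing by a step size, and out-of-range inputs land on 6.0 without an explicit clamp:

```python
# The eight non-negative values representable in E2M1; the list index is
# also the 3-bit magnitude code (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def quantize_e2m1(value):
    """Return the 4-bit E2M1 code (sign bit + magnitude index) nearest to value."""
    sign = 0x8 if value < 0 else 0x0
    magnitude = abs(value)
    # Nearest-neighbor search over the table; min() keeps the first
    # (smaller) candidate on an exact tie, which is an assumed rule.
    index = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - magnitude))
    return sign | index


def pack_fp4(codes):
    """Pack pairs of 4-bit codes into bytes, low nibble first (assumed order)."""
    if len(codes) % 2:
        raise ValueError("need an even number of FP4 codes")
    return bytes(
        (codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
        for i in range(0, len(codes), 2)
    )


values = [0.2, 0.3, -1.2, 2.4, 5.1, 100.0]
codes = [quantize_e2m1(v) for v in values]
packed = pack_fp4(codes)

print([hex(c) for c in codes])  # ['0x0', '0x1', '0xa', '0x4', '0x7', '0x7']
print(packed.hex())             # 104a77
```

Six input values become three bytes, matching the two-values-per-byte packing of the staged activation.
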

Bug fixes:
- _fold_global_scale: removed broken logical_widths branch
- unpack_ue4m3_u32: int32 for CUDA bitwise ops, .view() instead of .to(), N-D input support
- Correct expert param mapping for NVFP4 checkpoint
- SiLU applied per-slot (not after summing expert paths)
2026-05-15 11:38:18 +00:00


"""Symmetric buffer for NVLink cross-rank all-reduce in mega_moe.

Replaces deep_gemm.mega.SymmBuffer and get_symm_buffer_for_nvfp4_mega_moe.
API matches the DeepGEMM signature used in the vLLM deepseek_v4.py patch.
"""
import os

import torch
import torch.distributed as dist

MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))


class SymmBuffer:
    """Symmetric NVLink buffer for expert-parallel cross-rank communication.

    Matches the DeepGEMM SymmBuffer interface expected by the vLLM patch:
    - .x: staged activation (FP4 packed)
    - .x_sf: staged activation scales (UE4M3 packed)
    - .topk_idx: top-k expert indices
    - .topk_weights: top-k routing weights
    - .buffer: underlying CUDA buffer
    - .group: process group
    """

    def __init__(self, group, num_experts, max_num_tokens, top_k,
                 hidden_size, intermediate_size):
        self.group = group
        self.num_experts = num_experts
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.experts_start_idx = 0  # set by caller before kernel invocation

        device = torch.cuda.current_device()

        # NVFP4 packed E2M1: 2 FP4 values per byte → K//2 bytes per token.
        # Scales are UE4M3 (float8_e4m3fn), one per 16-element group → K//16
        # bytes per token, UNPACKED. This is what `stage_activation` produces
        # and what the CUTLASS NVFP4 block-scaled GEMM consumes directly.
        # (The DeepGEMM API packed 4 UE4M3 into one uint32 — we don't, because
        # our CUTLASS kernel reads scales as float8_e4m3fn.)
        sf_k_groups_hidden = hidden_size // 16
        sf_k_groups_inter = intermediate_size // 16

        # Staging buffers
        self.x = torch.empty(
            max_num_tokens, hidden_size // 2,
            dtype=torch.int8, device=device,
        )
        self.x_sf = torch.empty(
            max_num_tokens, sf_k_groups_hidden,
            dtype=torch.float8_e4m3fn, device=device,
        )
        self.topk_idx = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.int32, device=device,
        )
        self.topk_weights = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.float32, device=device,
        )

        # All-reduce buffer
        self.buffer = torch.empty(
            max_num_tokens, hidden_size,
            dtype=torch.bfloat16, device=device,
        )

        # Per-tensor global scale from stage_activation (fp32 scalar)
        # Applied as GEMM alpha: D = global_scale * (A_sf * A_fp4) @ (B_sf * B_fp4)
        self.input_global_scale = 1.0

        if MEGA_MOE_DEBUG:
            print(f"[SymmBuffer] x={self.x.shape} x_sf={self.x_sf.shape} "
                  f"topk_idx={self.topk_idx.shape} topk_weights={self.topk_weights.shape} "
                  f"buffer={self.buffer.shape}")


def get_symm_buffer_for_nvfp4_mega_moe(
    group,
    num_experts: int,
    max_num_tokens: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
) -> SymmBuffer:
    """Allocate a symmetric buffer for the NVFP4 mega_moe kernel.

    API matches deep_gemm.mega.get_symm_buffer_for_nvfp4_mega_moe.
    """
    return SymmBuffer(
        group, num_experts, max_num_tokens, top_k,
        hidden_size, intermediate_size,
    )
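
The allocation sizes implied by the shapes and dtypes above can be checked without a GPU. The following is a rough sketch, not part of the module: `staging_buffer_bytes` is a hypothetical helper, and the example sizes (8192 tokens, top-8 routing, hidden size 7168) are illustrative assumptions rather than values taken from the source.

```python
def staging_buffer_bytes(max_num_tokens, top_k, hidden_size):
    """Bytes per tensor, mirroring the shapes and dtypes SymmBuffer allocates."""
    if hidden_size % 16:
        raise ValueError("hidden_size must be a multiple of 16")
    return {
        # int8 storage, two packed E2M1 values per byte
        "x": max_num_tokens * (hidden_size // 2),
        # one float8_e4m3fn scale (1 byte) per 16-element group
        "x_sf": max_num_tokens * (hidden_size // 16),
        # int32 expert indices
        "topk_idx": max_num_tokens * top_k * 4,
        # float32 routing weights
        "topk_weights": max_num_tokens * top_k * 4,
        # bfloat16 all-reduce buffer
        "buffer": max_num_tokens * hidden_size * 2,
    }


# Hypothetical configuration, chosen only to make the arithmetic concrete.
sizes = staging_buffer_bytes(max_num_tokens=8192, top_k=8, hidden_size=7168)
total_mib = sum(sizes.values()) / 2 ** 20

for name, nbytes in sizes.items():
    print(f"{name:>12}: {nbytes / 2 ** 20:7.2f} MiB")
print(f"{'total':>12}: {total_mib:7.2f} MiB")  # 144.00 MiB
```

Under these assumed sizes the bfloat16 all-reduce buffer accounts for 112 of the 144 MiB, while the FP4 activations and their scales together take 31.5 MiB.
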