Major changes from initial TileLang prototype: Kernel: - CUTLASS NVFP4 block-scaled GEMM (SM100 Blackwell, OpClassBlockScaledTensorOp) - Slot-based dispatch: L1 GEMM → SiLU+Mul per-slot → L2 GEMM → index_add scatter - 1D slot_expert_ids passed to both L1 and L2 (no 2D topk_ids rebuild) - slot_token gathered in cutlass_grouped_nvfp4_gemm when provided SF Remap (source-first): - Iterates logical (m, k_sf) source grid, uses layout_sf(make_coord(m, k_sf)) for CUTLASS dest index — no idx2crd/flatten coordinate extraction - 2D kernel launch: dim3 block(32,8), grid over (K_sf, MN) - Uses cute::cosize() for physical allocation size (not cute::size) - SFA: (MN, K_sf) row-major; SFB: (K_sf, MN) row-major (col-major) Weight transform: - UE4M3 unpack with bit reinterpret (not value cast) - Global scale folding (weight_scale_2) for gate/up split - clamp(0,448) → float8_e4m3fn, transpose (N,K)→(K,N) for CUTLASS No prepack cache: - SFB remapped per-call inside CUTLASS (~µs, not the bottleneck) - See README for why prepack cache must never return (OOM, CUDA graphs, M-dependent layout, cross-layer collisions) Stage activation: - Nearest-neighbor E2M1 quantization (no clamp, no uniform steps) - Per-tensor global scale → alpha for L2 GEMM Bug fixes: - _fold_global_scale: removed broken logical_widths branch - unpack_ue4m3_u32: int32 for CUDA bitwise, view not to, ND support - Correct expert param mapping for NVFP4 checkpoint - SiLU applied per-slot (not after summing expert paths)
96 lines
3.4 KiB
Python
96 lines
3.4 KiB
Python
"""Symmetric buffer for NVLink cross-rank all-reduce in mega_moe.
|
|
|
|
Replaces deep_gemm.mega.SymmBuffer and get_symm_buffer_for_nvfp4_mega_moe.
|
|
API matches the DeepGEMM signature used in the vLLM deepseek_v4.py patch.
|
|
"""
|
|
|
|
import os
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
# Debug switch, read once at import time from the environment.
# Unset means 0 (off); a set value must parse as an integer or int() raises.
MEGA_MOE_DEBUG = int(os.getenv("MEGA_MOE_DEBUG", "0"))
|
|
|
|
|
|
class SymmBuffer:
    """Symmetric NVLink buffer for expert-parallel cross-rank communication.

    Matches the DeepGEMM SymmBuffer interface expected by the vLLM patch:
      - .x:            staged activation (FP4 packed)
      - .x_sf:         staged activation scales (UE4M3, one byte per scale)
      - .topk_idx:     top-k expert indices
      - .topk_weights: top-k routing weights
      - .buffer:       underlying CUDA buffer
      - .group:        process group

    NOTE: the constructor only allocates tensors with ``torch.empty`` on the
    current CUDA device and stores ``group``; it performs no communication
    and does not register the tensors with the process group itself.
    """

    def __init__(self, group, num_experts: int, max_num_tokens: int,
                 top_k: int, hidden_size: int,
                 intermediate_size: int) -> None:
        """Allocate the staging and all-reduce tensors.

        Args:
            group: Process group handle; stored as-is on ``self.group``.
            num_experts: Number of experts; stored, not used for allocation.
            max_num_tokens: Leading dimension of every allocated tensor.
            top_k: Trailing dimension of ``topk_idx`` / ``topk_weights``.
            hidden_size: Activation width K. Must be a multiple of 16.
            intermediate_size: Stored, not used for allocation.

        Raises:
            ValueError: If ``hidden_size`` is not a multiple of 16.
        """
        # Validate before touching CUDA: the packed-FP4 width (K // 2) and the
        # scale count (K // 16) use floor division, so a K that is not a
        # multiple of 16 would silently under-allocate the scale buffer and
        # leave the trailing elements without a scale.
        if hidden_size % 16 != 0:
            raise ValueError(
                "hidden_size must be a multiple of 16 for NVFP4 "
                f"(16-element scale groups), got {hidden_size}"
            )

        self.group = group
        self.num_experts = num_experts
        self.max_num_tokens = max_num_tokens
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.experts_start_idx = 0  # set by caller before kernel invocation

        device = torch.cuda.current_device()

        # NVFP4 packed E2M1: 2 FP4 values per byte → K//2 bytes per token.
        # Scales are UE4M3 (float8_e4m3fn), one per 16-element group → K//16
        # bytes per token, UNPACKED. This is what `stage_activation` produces
        # and what the CUTLASS NVFP4 block-scaled GEMM consumes directly.
        # (The DeepGEMM API packed 4 UE4M3 into one uint32 — we don't, because
        # our CUTLASS kernel reads scales as float8_e4m3fn.)
        sf_k_groups_hidden = hidden_size // 16

        # Staging buffers
        self.x = torch.empty(
            max_num_tokens, hidden_size // 2,
            dtype=torch.int8, device=device,
        )
        self.x_sf = torch.empty(
            max_num_tokens, sf_k_groups_hidden,
            dtype=torch.float8_e4m3fn, device=device,
        )
        self.topk_idx = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.int32, device=device,
        )
        self.topk_weights = torch.empty(
            max_num_tokens, top_k,
            dtype=torch.float32, device=device,
        )

        # All-reduce buffer
        self.buffer = torch.empty(
            max_num_tokens, hidden_size,
            dtype=torch.bfloat16, device=device,
        )

        # Per-tensor global scale from stage_activation (fp32 scalar)
        # Applied as GEMM alpha: D = global_scale * (A_sf * A_fp4) @ (B_sf * B_fp4)
        self.input_global_scale = 1.0

        if MEGA_MOE_DEBUG:
            print(f"[SymmBuffer] x={self.x.shape} x_sf={self.x_sf.shape} "
                  f"topk_idx={self.topk_idx.shape} topk_weights={self.topk_weights.shape} "
                  f"buffer={self.buffer.shape}")
|
|
|
|
|
|
def get_symm_buffer_for_nvfp4_mega_moe(
    group,
    num_experts: int,
    max_num_tokens: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
) -> SymmBuffer:
    """Build the SymmBuffer used by the NVFP4 mega_moe kernel.

    Thin factory: every argument is forwarded unchanged to the ``SymmBuffer``
    constructor and the new instance is returned.

    API matches deep_gemm.mega.get_symm_buffer_for_nvfp4_mega_moe.
    """
    symm_buffer = SymmBuffer(
        group=group,
        num_experts=num_experts,
        max_num_tokens=max_num_tokens,
        top_k=top_k,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
    )
    return symm_buffer