- nvfp4_mega_moe_l1: L1 GEMM (gate_up_proj) with FP4 dequant → BF16 GEMM
- nvfp4_mega_moe_l2: L2 GEMM (down_proj) with FP4 dequant → BF16 GEMM
- nvfp4_dequant.py: E2M1 packed → BF16 with UE4M3 block-16 scales
- tilelang_kernels.py: grouped expert GEMM with TileLang-compiled BF16 GEMM
- Full pipeline: L1 GEMM → SiLU+Mul → re-quantize → L2 GEMM → output
- MEGA_MOE_STATIC=1 bypass still works for pipeline testing

Current approach: dequantize FP4→BF16, then run a BF16 GEMM via TileLang T.gemm (auto-lowers to tcgen05 on Blackwell). This will be upgraded to native FP4 block-scaled MMA (tcgen05.mma kind::mxf4nvf4.block_scale with scale_vec::4X, the PTX variant for E2M1 data with UE4M3 scales) once TileLang adds E2M1+UE4M3 support.
137 lines
5.2 KiB
Python
137 lines
5.2 KiB
Python
"""
|
|
TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization.
|
|
|
|
This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer:
|
|
- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales
|
|
- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales
|
|
|
|
Current approach: dequantize FP4→BF16, then run a BF16 GEMM. This is correct
and functional. Once TileLang supports native block-scaled tcgen05.mma for
E2M1 data with UE4M3 scales (PTX kind::mxf4nvf4.block_scale with
scale_vec::4X), we'll switch to native FP4 block-scaled MMA for maximum
throughput.

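The grouped expert GEMM uses a "segmented" approach: group the routed tokens
by expert, then run one BF16 GEMM per expert. In this module the per-expert
GEMM calls torch.nn.functional.linear; the TileLang kernel built by
_make_bf16_gemm is not used on that path.
"""
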
import torch
|
|
import tilelang
|
|
import tilelang.language as T
|
|
|
|
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
|
|
|
|
# ---------------------------------------------------------------------------
# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05)
# ---------------------------------------------------------------------------

# Memo of built kernels, keyed by (M, N, K, block_M, block_N, block_K, num_stages).
# Filled by _make_bf16_gemm; entries are never evicted.
_kernel_cache = dict()


def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3):
    """Return a TileLang-compiled BF16 GEMM ``C = A @ B`` specialized to fixed shapes.

    The kernel reads ``A`` of shape (M, K) and ``B`` of shape (K, N), both
    bfloat16. It accumulates each (block_M, block_N) output tile in float32
    and stores it into ``C`` of shape (M, N) as bfloat16. ``out_idx=[2]`` marks
    ``C`` (parameter index 2) as the kernel's output tensor.

    Built kernels are memoized in the module-level ``_kernel_cache`` under the
    full argument tuple, so a repeated call returns the same object. Because M
    is part of the key, every distinct M yields a separate kernel, and the
    cache is never evicted.

    NOTE(review): confirm that TileLang's ``T.copy`` guards partial edge tiles
    when M/N/K are not multiples of the block sizes.
    """
    cache_key = (M, N, K, block_M, block_N, block_K, num_stages)
    try:
        return _kernel_cache[cache_key]
    except KeyError:
        pass  # cache miss: build the kernel below

    # The body below is TileLang DSL traced by tilelang.jit; keep it verbatim.
    @tilelang.jit(out_idx=[2])
    def bf16_gemm(
        A: T.Tensor((M, K), T.bfloat16),
        B: T.Tensor((K, N), T.bfloat16),
        C: T.Tensor((M, N), T.bfloat16),
    ):
        # Grid: bx walks N in block_N tiles, by walks M in block_M tiles.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), T.bfloat16)
            B_shared = T.alloc_shared((block_K, block_N), T.bfloat16)
            C_local = T.alloc_fragment((block_M, block_N), T.float32)

            T.clear(C_local)

            # Pipelined reduction over K in block_K steps.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Store the float32 accumulator tile into the bfloat16 output.
            T.copy(C_local, C[by * block_M, bx * block_N])

    _kernel_cache[cache_key] = bf16_gemm
    return bf16_gemm


# ---------------------------------------------------------------------------
# Grouped expert GEMM with FP4 dequantization
# ---------------------------------------------------------------------------

def grouped_gemm_fp4(
    x_bf16: torch.Tensor,  # (total_tokens, K_dim) bfloat16
    weights_fp4: torch.Tensor,  # (E, N, K//2) int8 packed E2M1
    scales_ue4m3: torch.Tensor,  # (E, N, K//16) float8_e4m3fn
    topk_ids: torch.Tensor,  # (num_tokens, NUM_TOPK) int32
    topk_weights: torch.Tensor,  # (num_tokens, NUM_TOPK) float32
) -> torch.Tensor:
    """Grouped expert GEMM over FP4 weights: dequantize to BF16, one GEMM per expert.

    Each (token ``t``, slot ``s``) pair with ``topk_ids[t, s] == e`` contributes
    ``topk_weights[t, s] * (x_bf16[t] @ W_e.T)``. Here ``W_e`` is
    ``unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e])``. Contributions are
    summed per token.

    Strategy:
      1. Count routed pairs per expert once, so experts with no tokens are
         skipped without per-expert / per-slot ``.any()`` checks.
      2. For each active expert, gather all of its (token, slot) pairs across
         every top-k slot. This gives a single GEMM per expert.
      3. Dequantize only that expert's weight (lazily), then run the BF16 GEMM
         with ``torch.nn.functional.linear``.
      4. Scatter-add the routing-weighted results into an FP32 accumulator,
         which is rounded to BF16 once at the end.

    Expert ids outside ``[0, E)`` (e.g. -1 padding) are ignored.

    Returns:
        ``(num_tokens, N)`` bfloat16 tensor.

    Raises:
        ValueError: if the activation width differs from the weight K
            (``2 * weights_fp4.shape[-1]``).
    """
    num_tokens, K_dim = x_bf16.shape
    E, N, K_half = weights_fp4.shape
    K = K_half * 2
    # Explicit raise instead of assert: asserts vanish under `python -O`.
    if K != K_dim:
        raise ValueError(f"Activation K={K_dim} doesn't match weight K={K}")
    device = x_bf16.device

    # Accumulate in FP32. Summing top-k contributions directly in a BF16
    # buffer rounds after every expert and loses low-order bits.
    acc = torch.zeros(num_tokens, N, dtype=torch.float32, device=device)

    # Per-expert pair counts over valid ids, computed once up front.
    flat_ids = topk_ids.reshape(-1)
    valid_ids = flat_ids[(flat_ids >= 0) & (flat_ids < E)].long()
    pair_counts = torch.bincount(valid_ids, minlength=E).tolist()

    for e, count in enumerate(pair_counts):
        if count == 0:
            continue

        # Every (token, slot) pair routed to expert e, across all top-k slots.
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)

        # Dequantize lazily: only experts that actually received tokens, and
        # only one expert's BF16 weight is alive at a time.
        w_bf16 = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e])  # (N, K)

        # BF16 GEMM: (n, K) @ (N, K).T -> (n, N)
        result = torch.nn.functional.linear(x_bf16[token_idx], w_bf16)

        route_w = topk_weights[token_idx, slot_idx].unsqueeze(-1).to(torch.float32)
        # index_add_ accumulates correctly even if a token lists expert e twice.
        acc.index_add_(0, token_idx, result.to(torch.float32) * route_w)

    return acc.to(torch.bfloat16)


# ---------------------------------------------------------------------------
# Convenience: grouped GEMM with uint32 packed scales
# ---------------------------------------------------------------------------

def grouped_gemm_fp4_packed_sf(
    x_bf16: torch.Tensor,
    weights_fp4: torch.Tensor,  # (E, N, K//2) int8
    scales_packed: torch.Tensor,  # (E, N, sf_k_groups) uint32 packed UE4M3
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
) -> torch.Tensor:
    """Run ``grouped_gemm_fp4`` on weights whose UE4M3 scales are packed into uint32 words.

    ``scales_packed`` is expanded with ``unpack_ue4m3_u32``. The result is then
    forwarded, with every other argument unchanged and in the same order, to
    ``grouped_gemm_fp4``, whose output is returned as-is.
    """
    return grouped_gemm_fp4(
        x_bf16,
        weights_fp4,
        unpack_ue4m3_u32(scales_packed),  # uint32-packed -> unpacked UE4M3 scales
        topk_ids,
        topk_weights,
    )