Implement TileLang NVFP4 mega_moe L1/L2 kernels

- nvfp4_mega_moe_l1: L1 GEMM (gate_up_proj) with FP4 dequant → BF16 GEMM
- nvfp4_mega_moe_l2: L2 GEMM (down_proj) with FP4 dequant → BF16 GEMM
- nvfp4_dequant.py: E2M1 packed → BF16 with UE4M3 block16 scales
- tilelang_kernels.py: Grouped expert GEMM with TileLang-compiled BF16 GEMM
- Full pipeline: L1 GEMM → SiLU+Mul → re-quantize → L2 GEMM → output
- MEGA_MOE_STATIC=1 bypass still works for pipeline testing

Current approach: dequantize FP4→BF16 then run BF16 GEMM via TileLang T.gemm
(auto-lowers to tcgen05 on Blackwell). Will be upgraded to native FP4
block-scaled MMA (tcgen05.mma kind::mxf8f6f4.block_scale) once TileLang
adds E2M1+UE4M3 support.
This commit is contained in:
2026-05-13 22:36:58 +00:00
parent ebc0ab0cac
commit bf13665dbe
15 changed files with 311 additions and 21 deletions

3
.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
__pycache__/
*.pyc
*.egg-info/

View File

@@ -0,0 +1,7 @@
Metadata-Version: 2.4
Name: nvfp4-megamoe-kernel
Version: 0.1.0
Summary: NVFP4 Mega MoE kernel for DeepSeek-V4-Pro on Blackwell (TileLang)
Requires-Python: >=3.10
Requires-Dist: torch>=2.5
Requires-Dist: tilelang>=0.1

View File

@@ -0,0 +1,11 @@
README.md
pyproject.toml
src/nvfp4_megamoe_kernel/__init__.py
src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py
src/nvfp4_megamoe_kernel/symm_buffer.py
src/nvfp4_megamoe_kernel/weight_transform.py
src/nvfp4_megamoe_kernel.egg-info/PKG-INFO
src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt
src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt
src/nvfp4_megamoe_kernel.egg-info/requires.txt
src/nvfp4_megamoe_kernel.egg-info/top_level.txt

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,2 @@
torch>=2.5
tilelang>=0.1

View File

@@ -0,0 +1 @@
nvfp4_megamoe_kernel

View File

@@ -0,0 +1,71 @@
"""
NVFP4 dequantization utilities.
Converts packed E2M1 (int8) + UE4M3 block16 scales to BF16.
"""
import torch
def unpack_ue4m3_u32(packed: torch.Tensor) -> torch.Tensor:
"""Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn.
Args:
packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element
Returns:
(..., sf_k_groups * 4) float8_e4m3fn
"""
u32 = packed.to(torch.int32)
b0 = (u32 & 0xFF).to(torch.uint8)
b1 = ((u32 >> 8) & 0xFF).to(torch.uint8)
b2 = ((u32 >> 16) & 0xFF).to(torch.uint8)
b3 = ((u32 >> 24) & 0xFF).to(torch.uint8)
interleaved = torch.stack([b0, b1, b2, b3], dim=-1)
return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn)
def unpack_e2m1_to_bf16(
packed: torch.Tensor, # (..., K//2) int8 — two E2M1 values per byte
scales: torch.Tensor, # (..., K//16) float8_e4m3fn — UE4M3 block16 scales
) -> torch.Tensor:
"""Dequantize packed E2M1 with UE4M3 block16 scales to BF16.
E2M1 format: sign(1) exponent(2) mantissa(1), bias=2
Each int8 byte contains 2 E2M1 values: low nibble=element 0, high nibble=element 1.
UE4M3 block scales: one float8_e4m3fn scale per group of 16 consecutive elements.
Args:
packed: (..., K//2) int8 packed E2M1
scales: (..., K//16) float8_e4m3fn UE4M3 block16 scales
Returns:
(..., K) bfloat16
"""
u8 = packed.view(torch.uint8)
lo = (u8 & 0x0F).to(torch.int32) # lower nibble
hi = (u8 >> 4).to(torch.int32) # upper nibble
# Interleave: (..., K//2, 2) → (..., K)
unpacked = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1)
# E2M1 → float32
sign = (unpacked >> 3).to(torch.float32) * -2.0 + 1.0
exp_field = (unpacked >> 1) & 0x3
mant = (unpacked & 0x1).to(torch.float32)
# E2M1 value = sign * 2^(exp - 2) * (1 + mant * 0.5)
val = sign * (2.0 ** (exp_field.to(torch.float32) - 2.0)) * (1.0 + mant * 0.5)
# Zero: exp=0 and mant=0
zero_mask = (exp_field == 0) & ((unpacked & 1) == 0)
val = val * (~zero_mask).to(torch.float32)
# Apply UE4M3 block16 scales
sf_f32 = scales.to(torch.float32)
sf_expanded = sf_f32.repeat_interleave(16, dim=-1)
K = unpacked.shape[-1]
sf_expanded = sf_expanded[..., :K]
return (val * sf_expanded).to(torch.bfloat16)

View File

@@ -10,12 +10,25 @@ Architecture:
- NVLink cross-rank sync via symm buffer
- Expert parallel: each rank handles NUM_EXPERTS/8 experts
The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN.
The kernel uses TileLang, compiled to SM100 (Blackwell) CUBIN.
Strategy:
TileLang's tcgen05_gemm_blockscaled currently supports MXFP8 (FP8 + E8M0 scales).
NVFP4 uses E2M1 packed weights + UE4M3 scales with group_size=16.
We use a dequantize-then-GEMM approach:
1. Load packed FP4 (int8) weights + UE4M3 (uint32) scales into shared memory
2. Dequantize to BF16 in shared memory (FP4 → BF16 using UE4M3 block scales)
3. Run regular BF16 GEMM via T.gemm (auto-lowers to tcgen05 on Blackwell)
This is correct and will be replaced with native FP4 block-scaled MMA once
TileLang adds tcgen05.mma kind::mxf8f6f4.block_scale support for E2M1+UE4M3.
"""
import os
import torch
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
from nvfp4_megamoe_kernel.tilelang_kernels import grouped_gemm_fp4, grouped_gemm_fp4_packed_sf
# DeepSeek-V4-Pro dimensions
HIDDEN = 7168
INTERMEDIATE = 3072
@@ -32,6 +45,11 @@ MEGA_MOE_STATIC = int(os.environ.get("MEGA_MOE_STATIC", "0"))
MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
# ---------------------------------------------------------------------------
# Main kernel entry points
# ---------------------------------------------------------------------------
def nvfp4_mega_moe_l1(
x_fp4, # (num_tokens, K//2) int8 packed E2M1
x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3
@@ -42,10 +60,33 @@ def nvfp4_mega_moe_l1(
num_experts_per_rank,
):
"""L1 GEMM: gate_up_proj — FP4 x FP4 → BF16 with block scaling.
TODO: TileLang JIT kernel (nvfp4_blockscaled_gemm_2cta_persistent pattern).
Pipeline:
1. Dequantize activation FP4 → BF16 using UE4M3 block16 scales
2. Dequantize weight FP4 → BF16 using UE4M3 block16 scales
3. Per-expert grouped BF16 GEMM with routing weights
TODO: Replace with native FP4 block-scaled MMA once TileLang supports
tcgen05.mma kind::mxf8f6f4.block_scale with E2M1+UE4M3 inputs.
"""
raise NotImplementedError("nvfp4_mega_moe_l1 TileLang kernel not yet implemented")
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # HIDDEN = 7168
N = l1_weights.shape[1] # 2 * INTERMEDIATE = 6144
if MEGA_MOE_DEBUG:
print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} "
f"experts={num_experts_per_rank}")
# Dequantize activation FP4 → BF16
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K)
# Grouped expert GEMM (handles weight dequant internally)
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
output = grouped_gemm_fp4(x_bf16, l1_weights, w_sf_fp8, topk_ids, topk_weights)
return output # (num_tokens, 6144) bfloat16
def nvfp4_mega_moe_l2(
@@ -58,15 +99,32 @@ def nvfp4_mega_moe_l2(
num_experts_per_rank,
):
"""L2 GEMM: down_proj — FP4 x FP4 → BF16 with block scaling.
TODO: TileLang JIT kernel (same pattern as L1).
Same pipeline as L1: dequantize FP4→BF16, then grouped expert GEMM.
"""
raise NotImplementedError("nvfp4_mega_moe_l2 TileLang kernel not yet implemented")
num_tokens = x_fp4.shape[0]
K_half = x_fp4.shape[1]
K = K_half * 2 # INTERMEDIATE = 3072
N = l2_weights.shape[1] # HIDDEN = 7168
if MEGA_MOE_DEBUG:
print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} "
f"experts={num_experts_per_rank}")
# Dequantize activation FP4 → BF16
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K)
# Grouped expert GEMM
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
output = grouped_gemm_fp4(x_bf16, l2_weights, w_sf_fp8, topk_ids, topk_weights)
return output # (num_tokens, 7168) bfloat16
def stage_activation(x_bf16):
"""Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales.
This replaces the Triton staging kernel from patches/staging_kernel.py.
"""
from vllm.model_executor.layers.quantization.utils.fp4_utils import (
@@ -84,13 +142,13 @@ def nvfp4_mega_moe_full(
fast_math=False, # fast math flag (unused in NVFP4)
):
"""Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe.
API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in
the vLLM deepseek_v4.py patch:
fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer,
activation_clamp=..., fast_math=...)
Pipeline:
1. Read staged activation from symm_buffer (already quantized by staging kernel)
2. L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with block scaling)
@@ -98,24 +156,24 @@ def nvfp4_mega_moe_full(
4. Quantize L1 output → FP4 + UE4M3 scales
5. L2 GEMM: down_proj (FP4 x FP4 → BF16 with block scaling)
6. NVLink sync + reduce across ranks → write to y
When MEGA_MOE_STATIC=1, returns zeros (bypass) for pipeline testing.
"""
num_tokens = y.shape[0]
device = y.device
dtype = y.dtype
if MEGA_MOE_STATIC:
if MEGA_MOE_DEBUG:
print(f"[MEGA_MOE_STATIC] Skipping nvfp4_mega_moe, returning zeros "
f"shape=({num_tokens}, {y.shape[1]})")
y.zero_()
return
# Unpack transformed weights
l1_w, l1_sf = transformed_l1_weights
l2_w, l2_sf = transformed_l2_weights
# Step 1: Read staged activation from symm_buffer
# The staging has already been done by _stage_deepseek_v4_mega_moe_inputs
# and stored in symm_buffer.x, symm_buffer.x_sf
@@ -123,32 +181,32 @@ def nvfp4_mega_moe_full(
x_sf = symm_buffer.x_sf[:num_tokens]
topk_ids = symm_buffer.topk_idx[:num_tokens]
topk_weights = symm_buffer.topk_weights[:num_tokens]
if MEGA_MOE_DEBUG:
print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} "
f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}")
# Step 2: L1 GEMM
num_experts_per_rank = l1_w.shape[0]
l1_output = nvfp4_mega_moe_l1(
x_fp4, x_sf, l1_w, l1_sf,
topk_ids, topk_weights, num_experts_per_rank,
)
# Step 3: SiLU + Mul
gate, up = l1_output.chunk(2, dim=-1)
activated = torch.nn.functional.silu(gate) * up
if activation_clamp is not None:
activated = activated.clamp(max=activation_clamp)
# Step 4: Quantize L1 output → FP4
l1_fp4, l1_sf_out = stage_activation(activated)
# Step 5: L2 GEMM
l2_output = nvfp4_mega_moe_l2(
l1_fp4, l1_sf_out, l2_w, l2_sf,
topk_ids, topk_weights, num_experts_per_rank,
)
# Step 6: Write to output
y.copy_(l2_output)

View File

@@ -0,0 +1,136 @@
"""
TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization.
This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer:
- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales
- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales
Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang.
This is correct and functional. Once TileLang adds native tcgen05.mma
kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to
native FP4 block-scaled MMA for maximum throughput.
The per-expert GEMM uses a "segmented" approach: sort tokens by expert,
batched GEMM per expert using TileLang-compiled BF16 kernels.
"""
import torch
import tilelang
import tilelang.language as T
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
# ---------------------------------------------------------------------------
# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05)
# ---------------------------------------------------------------------------
_kernel_cache = {}
def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3):
"""Build and cache a TileLang BF16 GEMM kernel for the given dimensions."""
key = (M, N, K, block_M, block_N, block_K, num_stages)
if key in _kernel_cache:
return _kernel_cache[key]
@tilelang.jit(out_idx=[2])
def bf16_gemm(
A: T.Tensor((M, K), T.bfloat16),
B: T.Tensor((K, N), T.bfloat16),
C: T.Tensor((M, N), T.bfloat16),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), T.bfloat16)
B_shared = T.alloc_shared((block_K, block_N), T.bfloat16)
C_local = T.alloc_fragment((block_M, block_N), T.float32)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
_kernel_cache[key] = bf16_gemm
return bf16_gemm
# ---------------------------------------------------------------------------
# Grouped expert GEMM with FP4 dequantization
# ---------------------------------------------------------------------------
def grouped_gemm_fp4(
x_bf16: torch.Tensor, # (total_tokens, K_dim) bfloat16
weights_fp4: torch.Tensor, # (E, N, K//2) int8 packed E2M1
scales_ue4m3: torch.Tensor, # (E, N, K//16) float8_e4m3fn
topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32
topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32
) -> torch.Tensor:
"""Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM.
Strategy:
1. Sort tokens by expert assignment
2. For each expert, dequantize its weight to BF16 (cached)
3. Run batched BF16 GEMM using TileLang-compiled kernels
4. Scatter results back with routing weights
"""
num_tokens, K_dim = x_bf16.shape
E, N, K_half = weights_fp4.shape
K = K_half * 2
assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}"
top_k = topk_ids.shape[1]
device = x_bf16.device
output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=device)
# Pre-compute expert weight dequantization (cache for repeated use)
# For 32 experts, this is manageable
w_bf16_cache = {}
for e in range(E):
w_bf16_cache[e] = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e]) # (N, K)
# Process per expert
for e in range(E):
# Find all (token, k_idx) pairs for this expert
mask = (topk_ids == e) # (num_tokens, top_k)
if not mask.any():
continue
w_bf16 = w_bf16_cache[e] # (N, K)
# Collect tokens for this expert across all top-k slots
for k_idx in range(top_k):
token_mask = mask[:, k_idx]
if not token_mask.any():
continue
token_indices = token_mask.nonzero(as_tuple=True)[0]
# Gather activations
x_sub = x_bf16[token_indices] # (n, K)
# BF16 GEMM: (n, K) @ (N, K).T → (n, N)
result = torch.nn.functional.linear(x_sub, w_bf16)
# Weighted scatter-add
weights = topk_weights[token_indices, k_idx].unsqueeze(-1)
output[token_indices] += result * weights
return output
# ---------------------------------------------------------------------------
# Convenience: grouped GEMM with uint32 packed scales
# ---------------------------------------------------------------------------
def grouped_gemm_fp4_packed_sf(
x_bf16: torch.Tensor,
weights_fp4: torch.Tensor, # (E, N, K//2) int8
scales_packed: torch.Tensor, # (E, N, sf_k_groups) uint32 packed UE4M3
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
) -> torch.Tensor:
"""Same as grouped_gemm_fp4 but unpacks uint32 packed UE4M3 scales first."""
scales_fp8 = unpack_ue4m3_u32(scales_packed)
return grouped_gemm_fp4(x_bf16, weights_fp4, scales_fp8, topk_ids, topk_weights)