cleanup: remove abandoned TileLang and Mojo files

- Deleted: layout.mojo, mega_moe.mojo, quantize.mojo (Mojo attempt)
- Deleted: nvfp4_blockscaled_gemm.py, staging.py, nvfp4_mega_moe.py (TileLang top-level)
- Deleted: tilelang_nvfp4_gemm.py, tilelang_kernels.py, nvfp4_dequant.py (TileLang package)
- Deleted: src/weight_transform.py (duplicate of package version)
- Fixed nvfp4_mega_moe.py: inlined unpack_ue4m3_u32, removed TileLang fallback imports
- Fixed weight_transform.py: renamed function, removed TileLang alias, updated docs
- Fixed __init__.py: removed TileLang alias, updated docstring
- CUTLASS is the only kernel path now
This commit is contained in:
2026-05-14 12:44:47 +00:00
parent 802c4ee12c
commit d3f35c9465
19 changed files with 36 additions and 1636 deletions

View File

@@ -1,32 +0,0 @@
"""
NVFP4 weight transformation and SF layout utilities.
Port of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe
"""
from math import ceil_div
fn fold_global_scale_into_block_scales(
    weight_scale: Tensor[float8_e4m3fn],  # (N, K//16) UE4M3 block scales
    weight_scale_2: Tensor[float32],  # (num_logical,) or scalar global scale
    logical_widths: List[int],  # per-logical-weight row counts
) -> Tensor[float32]:
    """Fold global scale into block scales: UE4M3 * FP32 -> FP32

    Placeholder only: the body is `...`, so nothing is computed yet. The
    comments below record the intended steps, not implemented behaviour.
    """
    # Convert UE4M3 to float32, multiply by global scale
    # For MergedColumnParallelLinear, expand per-logical global scale
    # (presumably one global scale per entry of `logical_widths` — TODO confirm)
    ...
fn pack_ue4m3_to_int32(sf: Tensor[float8_e4m3fn]) -> Tensor[int32]:
    """Pack 4 UE4M3 values (4 bytes) into one int32 for DeepGEMM TMA

    Placeholder only: the body is `...`. The byte order inside each int32 is
    not specified here — TODO confirm against the consumer before implementing.
    """
    # View as uint8, pack 4 consecutive bytes into int32
    ...
fn transform_sf_into_required_layout(
    sf_mn: Tensor[int32],  # MN-major packed SF
    N: int, K: int,
    recipe: Tuple[int, int],  # (gran_mn, gran_k)
    num_groups: int,
) -> Tensor[int32]:
    """Transform SF into TMA-aligned UTCCP layout for DeepGEMM

    Placeholder only: the body is `...`; no layout transform is performed.
    """
    # Call into DeepGEMM's C++ layout transform
    ...

View File

@@ -1,24 +0,0 @@
"""
NVFP4 Mega MoE Kernel — Mojo Rewrite
This is a from-scratch rewrite of the DeepGEMM fp8_nvfp4_mega_moe kernel.
The CUDA C++ version crashes on B200 with CUDA_ERROR_LAUNCH_FAILED in the
SM100_MMA_MXF4NVF4 instruction. This Mojo rewrite aims to:
1. Produce a correct, working NVFP4 mega_moe kernel
2. Be more maintainable than 1200+ lines of CUDA template metaprogramming
3. Leverage Mojo's GPU programming model for cleaner TMA/UMMA integration
NVFP4 format:
- Weights: E2M1 packed (2 x 4-bit values per byte), int8 container
- Block scales: UE4M3 (float8_e4m3fn), group_size=16
- Global scale: float32 scalar
- Activation: FP8 e4m3fn with UE8M0 per-token scales
Key differences from MXFP4 (group_size=32, UE8M0):
- 2 SF K-columns per BLOCK_K (128/16/4=2) instead of 1
- mxf4nvf4 UMMA instruction with scale_vec::4X
- Different weight transformation (gran_k=16 vs gran_k=32)
"""
# TODO: Implement as Mojo GPU kernel
# Waiting for Mojo GPU programming docs / SDK setup

View File

@@ -1,256 +0,0 @@
"""
NVFP4 Block-Scaled GEMM — TileLang
2CTA persistent kernel for Blackwell SM100
Adapted from tilelang/examples/blockscaled_gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py
Key NVFP4 differences from MXFP8:
- sf_granularity_k=16 (UE4M3 block16) instead of 128 (UE8M0)
- 2 SF uint32 words per BLOCK_K instead of 1 (sf_load_period=1, not 4)
- sf_a_id/sf_b_id cycle through 0,1,2,3 within each SF word
- Must call tcgen05_gemm_blockscaled twice per K-block with different SF words
- in_dtype="float4_e2m1fnx2" (packed FP4) instead of "float8_e4m3fn" (FP8)
"""
import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
from tilelang.profiler import do_bench
# NVFP4 constants
NVFP4_SF_GRAN_K = 16 # UE4M3 group_size
NVFP4_SF_PACK = 4 # 4 UE4M3 values per uint32
NVFP4_SF_K_PER_WORD = NVFP4_SF_GRAN_K * NVFP4_SF_PACK # 64 K-elements per uint32 word
# For BLOCK_K=128: 128/64 = 2 SF words per K-block
# Each word covers 4 UMMA sub-atoms (sf_id 0-3)
# NVFP4 block-scaled GEMM expressed in the TileLang DSL.
#
# Launch shape as written below: one CTA per SM (`sm_num` blocks), 256 threads
# per CTA, clusters of 2 CTAs. Work is split by warp index:
#   warp 0            -> issues TMA copies of A, B, SFA, SFB into shared memory
#   warp 1 (CTA 0)    -> copies scale factors to tensor memory and issues MMAs
#   warp 2            -> transposes scale factors in shared memory
#   after the branches -> epilogue that copies the accumulator out to C
# Each role walks the same (wave, tile, k) schedule and the roles are
# synchronised with the mbarriers allocated below.
#
# Lines tagged NOTE(review) flag things that look inconsistent from the code
# alone; they are observations to be confirmed, not verified defects.
@tilelang.jit
def nvfp4_blockscaled_gemm_2cta_persistent(
    A,  # (M, K) packed FP4
    B,  # (N, K) packed FP4 (K-major for NN)
    SFA,  # (sf_k_groups * M) uint32 packed UE4M3
    SFB,  # (sf_k_groups * N) uint32 packed UE4M3
    block_M=128,
    block_N=256,
    block_K=128,
    in_dtype="float4_e2m1fnx2",
    out_dtype="bfloat16",
    accum_dtype="float32",
    num_stages=2,
    sf_granularity_k=NVFP4_SF_GRAN_K,
    use_tma_store=True,
    store_block_N=64,
):
    M, N, K = T.const("M, N, K")
    # The tile shape is effectively fixed: any other values fail these asserts.
    assert block_M == 128
    assert block_N == 256
    assert block_K == 128
    # Each CTA of the 2-CTA cluster handles half of the N tile.
    half_N = block_N // 2
    k_iters = T.ceildiv(K, block_K)
    # NVFP4 SF layout:
    # sf_granularity_k=16, pack_factor=4 → 64 K-elements per uint32
    # BLOCK_K=128 → 2 SF words per K-block
    # sf_load_period=1 (load every K-block, unlike MXFP8 which loads every 4)
    sf_words_per_k_block = block_K // (sf_granularity_k * NVFP4_SF_PACK)  # 2
    sf_k_groups = T.ceildiv(K, sf_granularity_k * NVFP4_SF_PACK)  # total SF groups across K
    A: T.Tensor[[M, K], in_dtype]
    # NOTE(review): B is declared [N, K] here but indexed as
    # B[k * block_K, by * block_N + ...] in the TMA load below, i.e. in (K, N)
    # order — confirm which layout is intended.
    B: T.Tensor[[N, K], in_dtype]
    SFA: T.Tensor[[sf_k_groups * M], T.uint32]
    SFB: T.Tensor[[sf_k_groups * N], T.uint32]
    C = T.empty((M, N), out_dtype)
    sm_num = driver.get_num_sms()
    num_clusters = sm_num // 2
    m_blocks = T.ceildiv(M, block_M)
    # NOTE(review): floor division — with an odd m_blocks the last M block has
    # no cluster; confirm callers guarantee an even number of M blocks.
    m_clusters = m_blocks // 2
    n_blocks = T.ceildiv(N, block_N)
    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
    # Tile-rasterisation group width used in the tile_id -> (bx, by) mapping
    # below. It is a separate quantity from sf_granularity_k even though both
    # equal 16.
    group_size = 16
    assert n_blocks % (2 * group_size) == 0
    with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id):
        cta_id = T.block_rank_in_cluster()
        T.assume(cta_id < 2)
        # Shared memory — FP4 packed (K is halved because 2 values per byte)
        A_shared = T.alloc_shared((num_stages, block_M, block_K // 2), "uint8")
        B_shared = T.alloc_shared((num_stages, block_K // 2, half_N), "uint8")
        # Shared memory for SF — 2 uint32 words per K-block for NVFP4
        SFA_shared = T.alloc_shared((num_stages, block_M, sf_words_per_k_block), T.uint32)
        SFB_shared = T.alloc_shared((num_stages, block_N, sf_words_per_k_block), T.uint32)
        # Tensor memory
        C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
        SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], T.uint32)
        # NOTE(review): SFB_tmem uses block_M for its first dimension although
        # SFB_shared has block_N rows — confirm this is intentional.
        SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], T.uint32)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
        C_shared = T.alloc_shared((block_M, store_block_N), out_dtype)
        # Barriers, one per pipeline stage unless noted:
        #   loaded       — arrived on by the load warp, waited on by warp 2
        #   with_sf_full — arrived on by warp 2, waited on by the MMA warp
        #   consumed     — passed to the MMA as `mbar`, waited on by the load warp
        #   tmem_full    — arrived on after a tile's K loop, waited on by the epilogue
        #   tmem_empty   — arrived on by the epilogue
        loaded = T.alloc_barrier([32] * num_stages)
        with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages)
        consumed = T.alloc_cluster_barrier([1] * num_stages)
        tmem_full = T.alloc_cluster_barrier([1])
        tmem_empty = T.alloc_cluster_barrier([128 * 2])
        tx = T.get_thread_binding()
        warp_idx = tx // 32
        if warp_idx == 0:
            # Warp 0: TMA load
            for w in T.serial(waves):
                # Map (wave, cluster) to an output tile; the same arithmetic is
                # repeated verbatim in every role so all warps agree on the tile.
                cluster_id = block_id // 2
                tile_id = num_clusters * w + cluster_id
                bx_cluster = (tile_id // group_size) % m_clusters
                bx = bx_cluster * 2 + cta_id
                by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size
                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_iters):
                        # `phase` counts K-blocks across waves so stage and
                        # parity keep advancing from one tile to the next.
                        phase = w * k_iters + k
                        stage = phase % num_stages
                        parity = (phase // num_stages) & 1
                        # Wait until the MMA side has released this stage.
                        T.mbarrier_wait_parity(consumed[stage], parity ^ 1)
                        # TMA load A (packed FP4)
                        T.tma_copy(
                            A[bx * block_M, k * block_K],
                            A_shared[stage, :, :],
                            barrier=loaded[stage],
                        )
                        # TMA load B (packed FP4)
                        T.tma_copy(
                            B[k * block_K, by * block_N + cta_id * half_N],
                            B_shared[stage, :, :],
                            barrier=loaded[stage],
                        )
                        # TMA load SFA — 2 words per K-block for NVFP4
                        # Word 0 covers K[k*128 : k*128+64]
                        # Word 1 covers K[k*128+64 : k*128+128]
                        for sf_word in T.unroll(sf_words_per_k_block):
                            sf_group = k * sf_words_per_k_block + sf_word
                            T.tma_copy(
                                SFA[sf_group * M + bx * block_M :
                                    sf_group * M + (bx + 1) * block_M],
                                SFA_shared[stage, :, sf_word],
                                barrier=loaded[stage],
                            )
                        # TMA load SFB — 2 words per K-block
                        for sf_word in T.unroll(sf_words_per_k_block):
                            sf_group = k * sf_words_per_k_block + sf_word
                            T.tma_copy(
                                SFB[sf_group * N + by * block_N :
                                    sf_group * N + (by + 1) * block_N],
                                SFB_shared[stage, :, sf_word],
                                barrier=loaded[stage],
                            )
                        T.mbarrier_arrive(loaded[stage])
        elif warp_idx == 1 and cta_id == 0:
            # Warp 1: MMA issue + UTCCP
            for w in T.serial(waves):
                cluster_id = block_id // 2
                tile_id = num_clusters * w + cluster_id
                bx_cluster = (tile_id // group_size) % m_clusters
                bx = bx_cluster * 2 + cta_id
                by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size
                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_iters):
                        phase = w * k_iters + k
                        stage = phase % num_stages
                        parity = (phase // num_stages) & 1
                        T.mbarrier_wait_parity(with_sf_full[stage], parity)
                        # NVFP4: 2 SF words per K-block → 2 sub-GEMM calls
                        # Each sub-GEMM covers 64 K-elements (BLOCK_K/2)
                        # NOTE(review): both iterations pass the whole
                        # A_shared[stage] / B_shared[stage] tile; no 64-element
                        # K slice is taken here. If the call consumes the full
                        # tile, each K-block would be accumulated twice — confirm.
                        # NOTE(review): `consumed[stage]` is passed as `mbar` in
                        # both iterations while that barrier was allocated with
                        # an arrival count of 1 — confirm it is signalled only
                        # once per stage.
                        for sf_word in T.unroll(sf_words_per_k_block):
                            # UTCCP: copy SF word from shared to tensor memory
                            # (both SF words target the same SFA_tmem/SFB_tmem)
                            T.tcgen05_cp_warpx4(
                                SFA_shared[stage, :, sf_word],
                                SFA_tmem,
                                use_2cta=True,
                            )
                            T.tcgen05_cp_warpx4(
                                SFB_shared[stage, :, sf_word],
                                SFB_tmem,
                                use_2cta=True,
                            )
                            # Block-scaled GEMM
                            # For NVFP4: sf_a_id selects which of 4 packed UE4M3
                            # values in the uint32 word to use.
                            # With sf_granularity_k=16 and UMMA_K=64:
                            # 4 SF values per UMMA atom, sf_id=0 covers all 4
                            # (the hardware auto-cycles through 0,1,2,3 internally)
                            # NOTE(review): the module docstring instead says
                            # sf_a_id/sf_b_id must cycle through 0,1,2,3; the two
                            # statements disagree — verify against the
                            # tcgen05_gemm_blockscaled documentation.
                            T.tcgen05_gemm_blockscaled(
                                A_shared[stage, :, :],
                                B_shared[stage, :, :],
                                C_tmem,
                                SFA_tmem,
                                SFB_tmem,
                                transpose_B=True,
                                mbar=consumed[stage],
                                clear_accum=(k == 0 and sf_word == 0),
                                sf_a_id=0,
                                sf_b_id=0,
                                use_2cta=True,
                            )
                    # All K-blocks for this tile have been issued.
                    T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True)
        elif warp_idx == 2:
            # Warp 2: SF transpose
            for w in T.serial(waves):
                cluster_id = block_id // 2
                tile_id = num_clusters * w + cluster_id
                bx_cluster = (tile_id // group_size) % m_clusters
                bx = bx_cluster * 2 + cta_id
                by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size
                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_iters):
                        phase = w * k_iters + k
                        stage = phase % num_stages
                        parity = (phase // num_stages) & 1
                        T.mbarrier_wait_parity(loaded[stage], parity)
                        # Transpose all SF words for this K-block
                        for sf_word in T.unroll(sf_words_per_k_block):
                            T.tcgen05_sf_warp_transpose(SFA_shared[stage, :, sf_word])
                            T.tcgen05_sf_warp_transpose(SFB_shared[stage, :, sf_word])
                        T.fence_proxy_async()
                        T.mbarrier_arrive(with_sf_full[stage], 0)
        # Epilogue: write C from tmem to global memory
        for w in T.serial(waves):
            cluster_id = block_id // 2
            tile_id = num_clusters * w + cluster_id
            bx_cluster = (tile_id // group_size) % m_clusters
            bx = bx_cluster * 2 + cta_id
            by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size
            if bx * block_M < M and by * block_N < N:
                T.mbarrier_wait_parity(tmem_full, w & 1)
                T.copy(C_tmem, C_local)
                T.copy(C_local, C_local_cast)
                # NOTE(review): C_local_cast is (block_M, block_N) while
                # C_shared is (block_M, store_block_N); the copy is not chunked
                # over N here — confirm the shapes are compatible.
                T.copy(C_local_cast, C_shared)
                if use_tma_store:
                    T.copy(C_shared, C[bx * block_M, by * block_N])
                else:
                    T.copy(C_shared, C[bx * block_M, by * block_N], disable_tma=True)
                # NOTE(review): no wait on tmem_empty is visible in this
                # function, so nothing here stops the MMA warp from reusing
                # C_tmem for the next wave before this copy-out finishes.
                T.mbarrier_arrive(tmem_empty, 0)
    return C

View File

@@ -1,163 +0,0 @@
"""
NVFP4 Mega MoE Kernel — Full MoE with expert parallelism.
This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM.
Architecture:
- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- SiLU+Mul activation
- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales)
- NVLink cross-rank sync via symm buffer
- Expert parallel: each rank handles NUM_EXPERTS/8 experts
The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN.
"""
import torch
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
# DeepSeek-V4-Pro dimensions
HIDDEN = 7168
INTERMEDIATE = 3072
NUM_EXPERTS = 256
NUM_RANKS = 8
NUM_TOPK = 6
# Block sizes for the block-scaled GEMM
BLOCK_M = 192 # tokens per tile
BLOCK_K = 128
BLOCK_N = 128
# NVFP4 scale parameters
SF_GRANULARITY_K = 16 # UE4M3 group_size
SF_PACK_FACTOR = 4 # 4 UE4M3 values per uint32
SF_WORDS_PER_K_BLOCK = BLOCK_K // (SF_GRANULARITY_K * SF_PACK_FACTOR) # 2
@tilelang.jit
def nvfp4_mega_moe_l1(
    # Activation (packed FP4, from staging kernel)
    X,  # (num_tokens, K//2) int8 packed E2M1
    X_SF,  # (num_tokens, sf_k_groups) uint32 packed UE4M3
    # L1 weights (pre-transformed for mega_moe)
    L1_W,  # (num_experts_per_rank, 2*INTERMEDIATE, K//2) int8 K-major
    L1_SF,  # (num_experts_per_rank, 2*INTERMEDIATE, sf_k_groups) uint32 TMA-aligned
    # Routing
    TopkIdx,  # (num_tokens, NUM_TOPK) int32
    TopkW,  # (num_tokens, NUM_TOPK) float32
    # Output
    Y,  # (num_tokens, 2*INTERMEDIATE) bfloat16
    num_experts_per_rank,
):
    """
    L1 GEMM for mega_moe: gate_up_proj
    X(FP4) @ L1_W(FP4)^T → Y(BF16) with UE4M3 block scaling
    This handles the gate_up_proj for all experts in the rank.
    Each token is routed to NUM_TOPK experts, and the GEMM is computed
    per expert group using persistent scheduling.

    Status: skeleton only. The kernel body below sets up the launch grid and
    ends in `pass`; no loads, MMAs or stores are issued and Y is never written.
    The description above is the intended design, not current behaviour.
    """
    num_tokens = T.const("num_tokens")
    K = T.const("K")
    INTER = T.const("INTERMEDIATE")
    # k_iters and sf_k_groups are computed but not yet used by the stub body.
    k_iters = T.ceildiv(K, BLOCK_K)
    sf_k_groups = T.ceildiv(K, SF_GRANULARITY_K * SF_PACK_FACTOR)
    # Grid: one block per (token tile, output-column tile), clusters of 2 CTAs.
    with T.Kernel(
        T.ceildiv(num_tokens, BLOCK_M),
        T.ceildiv(2 * INTER, BLOCK_N),
        threads=256,
        cluster_dims=2,
    ) as (bx, by):
        cta_id = T.block_rank_in_cluster()
        T.assume(cta_id < 2)
        # ... (persistent scheduling, expert routing, L1 GEMM with block scaling)
        # This follows the same pattern as nvfp4_blockscaled_gemm_2cta_persistent
        # but adds expert scheduling and topk weight scaling
        pass  # Full implementation in progress
@tilelang.jit
def nvfp4_mega_moe_l2(
    # L1 output (quantized to FP4 for L2 input)
    X,  # (num_tokens, INTER//2) int8 packed E2M1
    X_SF,  # (num_tokens, sf_k_groups) uint32 packed UE4M3
    # L2 weights
    L2_W,  # (num_experts_per_rank, HIDDEN, INTER//2) int8 K-major
    L2_SF,  # (num_experts_per_rank, HIDDEN, sf_k_groups) uint32 TMA-aligned
    # Routing
    TopkIdx,  # (num_tokens, NUM_TOPK) int32
    TopkW,  # (num_tokens, NUM_TOPK) float32
    # Output
    Y,  # (num_tokens, HIDDEN) bfloat16
    num_experts_per_rank,
):
    """
    L2 GEMM for mega_moe: down_proj
    X(FP4) @ L2_W(FP4)^T → Y(BF16) with UE4M3 block scaling
    After SiLU+Mul on the L1 output, the result is quantized to FP4
    and fed into L2.

    Status: unimplemented. The body is a bare `pass`, so the function performs
    no computation and Y is never written.
    """
    pass  # Symmetric to L1
def nvfp4_mega_moe_full(
    hidden_states,  # (num_tokens, HIDDEN) bfloat16
    topk_weights,  # (num_tokens, NUM_TOPK) float32
    topk_ids,  # (num_tokens, NUM_TOPK) int32
    l1_weights,  # L1 weights (transformed for mega_moe)
    l1_scales,  # L1 UE4M3 scales (transformed)
    l2_weights,  # L2 weights (transformed for mega_moe)
    l2_scales,  # L2 UE4M3 scales (transformed)
    symm_buffer,  # NVLink symm buffer for cross-rank sync
):
    """Run the MoE forward pipeline end to end.

    Pipeline, in order:
      1. stage_activation(hidden_states)      -> FP4 activations + scales
      2. nvfp4_mega_moe_l1(...)               -> gate/up projection
      3. silu(gate) * up                      -> activation
      4. stage_activation(activated)          -> FP4 activations + scales
      5. nvfp4_mega_moe_l2(...)               -> down projection
      6. nvlink_reduce(...)                   -> value returned to the caller

    NOTE(review): the L1/L2 kernels in this file declare eight parameters
    (including `Y` and `num_experts_per_rank`) but are called here with six —
    confirm how the JIT wrapper supplies the remaining two.
    NOTE(review): `nvlink_reduce` is not defined in the visible part of this
    file — confirm where it comes from.
    """
    # 1) Quantise the incoming activations.
    act_fp4, act_sf = stage_activation(hidden_states)

    # 2) gate_up projection.
    gate_up = nvfp4_mega_moe_l1(
        act_fp4, act_sf, l1_weights, l1_scales, topk_ids, topk_weights)

    # 3) Split the fused projection in half along the last dim and apply SiLU * up.
    gate_half, up_half = gate_up.chunk(2, dim=-1)
    gated = torch.nn.functional.silu(gate_half) * up_half

    # 4) Re-quantise for the second GEMM.
    mid_fp4, mid_sf = stage_activation(gated)

    # 5) down projection.
    down = nvfp4_mega_moe_l2(
        mid_fp4, mid_sf, l2_weights, l2_scales, topk_ids, topk_weights)

    # 6) Cross-rank reduction.
    return nvlink_reduce(down, topk_weights, symm_buffer)
def stage_activation(x_bf16):
    """Quantize an activation tensor by delegating to DeepGEMM.

    Thin wrapper: forwards `x_bf16` to `deep_gemm.per_token_cast_to_fp4` and
    returns whatever that helper returns (callers in this file unpack it as a
    `(values, scales)` pair). The exact output format is defined by DeepGEMM —
    presumably E2M1 values with block scales; verify against that library.

    The import is done lazily so this module can be imported without
    deep_gemm installed.
    """
    # TODO: Write as TileLang kernel for full pipeline integration
    from deep_gemm import per_token_cast_to_fp4 as _cast_to_fp4

    quantized = _cast_to_fp4(x_bf16)
    return quantized

View File

@@ -1,4 +1,4 @@
"""NVFP4 Mega MoE Kernel — TileLang implementation for DeepSeek-V4-Pro on Blackwell."""
"""NVFP4 Mega MoE Kernel — CUTLASS implementation for DeepSeek-V4-Pro on Blackwell."""
from nvfp4_megamoe_kernel.nvfp4_mega_moe import (
nvfp4_mega_moe_full,
@@ -7,7 +7,7 @@ from nvfp4_megamoe_kernel.nvfp4_mega_moe import (
stage_activation,
)
from nvfp4_megamoe_kernel.weight_transform import (
transform_nvfp4_weights_for_tilelang as transform_nvfp4_weights_for_mega_moe,
transform_nvfp4_weights_for_mega_moe,
)
from nvfp4_megamoe_kernel.symm_buffer import (
SymmBuffer,

View File

@@ -1,71 +0,0 @@
"""
NVFP4 dequantization utilities.
Converts packed E2M1 (int8) + UE4M3 block16 scales to BF16.
"""
import torch
def unpack_ue4m3_u32(packed: torch.Tensor) -> torch.Tensor:
"""Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn.
Args:
packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element
Returns:
(..., sf_k_groups * 4) float8_e4m3fn
"""
u32 = packed.to(torch.int32)
b0 = (u32 & 0xFF).to(torch.uint8)
b1 = ((u32 >> 8) & 0xFF).to(torch.uint8)
b2 = ((u32 >> 16) & 0xFF).to(torch.uint8)
b3 = ((u32 >> 24) & 0xFF).to(torch.uint8)
interleaved = torch.stack([b0, b1, b2, b3], dim=-1)
return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn)
def unpack_e2m1_to_bf16(
    packed: torch.Tensor,  # (..., K//2) int8 — two E2M1 values per byte
    scales: torch.Tensor,  # (..., K//16) float8_e4m3fn — UE4M3 block16 scales
) -> torch.Tensor:
    """Dequantize packed E2M1 with UE4M3 block16 scales to BF16.

    E2M1 format: sign(1) exponent(2) mantissa(1), exponent bias = 1, with a
    single subnormal. The eight magnitudes, indexed by the low three bits, are
    0, 0.5, 1, 1.5, 2, 3, 4, 6; bit 3 is the sign.

    Each byte contains 2 E2M1 values: low nibble = element 0, high nibble =
    element 1. One scale applies to each group of 16 consecutive elements along
    the last dimension.

    Args:
        packed: (..., K//2) int8 (or uint8) packed E2M1
        scales: (..., K//16) block scales; any dtype convertible to float32
    Returns:
        (..., K) bfloat16
    """
    u8 = packed.view(torch.uint8)
    lo = u8 & 0x0F  # element 0 of each byte
    hi = u8 >> 4  # element 1 of each byte
    # Interleave: (..., K//2, 2) → (..., K); int64 so it can be used as an index.
    codes = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1).long()

    # Decode through a lookup table rather than a closed-form expression: the
    # previous formula 2^(exp-2) * (1 + mant/2) used the wrong bias and treated
    # the subnormal code (exp=0, mant=1) as a normal number, so every non-zero
    # value came out wrong (e.g. code 0b0111 decoded to 3.0 instead of 6.0).
    magnitudes = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
        dtype=torch.float32,
        device=packed.device,
    )
    val = magnitudes[codes & 0x7]
    val = torch.where((codes & 0x8) != 0, -val, val)

    # Apply block16 scales: repeat each scale over its 16 elements, then trim
    # in case K is not a multiple of 16.
    K = codes.shape[-1]
    sf_expanded = scales.to(torch.float32).repeat_interleave(16, dim=-1)[..., :K]
    return (val * sf_expanded).to(torch.bfloat16)

View File

@@ -26,12 +26,24 @@ avoiding the costly dequantization step.
import os
import torch
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
from nvfp4_megamoe_kernel.tilelang_nvfp4_gemm import (
nvfp4_blockscaled_gemm,
grouped_gemm_nvfp4_native,
grouped_gemm_nvfp4_packed_sf,
)
def unpack_ue4m3_u32(x_u32):
"""Unpack uint32 packed UE4M3 scales to float8_e4m3fn.
Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32].
"""
x_u32 = x_u32.contiguous()
M, N = x_u32.shape
out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
# Vectorized unpack: extract 4 bytes from each uint32
b0 = (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
b1 = ((x_u32 >> 8) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
b2 = ((x_u32 >> 16) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
b3 = ((x_u32 >> 24) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
out[:, 0::4] = b0
out[:, 1::4] = b1
out[:, 2::4] = b2
out[:, 3::4] = b3
return out
# CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell)
# Primary path: uses CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled
@@ -97,21 +109,11 @@ def nvfp4_mega_moe_l1(
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
# Use CUTLASS native block-scaled GEMM if available
if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE:
output = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l1_weights, w_sf_fp8,
topk_ids, topk_weights,
)
else:
# Fallback to TileLang path
output = grouped_gemm_nvfp4_native(
x_fp4, x_sf_fp8,
l1_weights, w_sf_fp8,
topk_ids, topk_weights,
)
output = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l1_weights, w_sf_fp8,
topk_ids, topk_weights,
)
return output # (num_tokens, 6144) bfloat16
@@ -141,21 +143,11 @@ def nvfp4_mega_moe_l2(
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
# Use CUTLASS native block-scaled GEMM if available
if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE:
output = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l2_weights, w_sf_fp8,
topk_ids, topk_weights,
)
else:
# Fallback to TileLang path
output = grouped_gemm_nvfp4_native(
x_fp4, x_sf_fp8,
l2_weights, w_sf_fp8,
topk_ids, topk_weights,
)
output = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l2_weights, w_sf_fp8,
topk_ids, topk_weights,
)
return output # (num_tokens, 7168) bfloat16

View File

@@ -1,136 +0,0 @@
"""
TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization.
This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer:
- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales
- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales
Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang.
This is correct and functional. Once TileLang adds native tcgen05.mma
kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to
native FP4 block-scaled MMA for maximum throughput.
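The per-expert GEMM uses a "segmented" approach: for each expert, gather the
tokens routed to it and run one GEMM over that group. Note: grouped_gemm_fp4
currently performs that GEMM with torch.nn.functional.linear; the TileLang
BF16 kernel built by _make_bf16_gemm is not called from it.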
"""
import torch
import tilelang
import tilelang.language as T
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
# ---------------------------------------------------------------------------
# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05)
# ---------------------------------------------------------------------------
_kernel_cache = {}
def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3):
    """Build and cache a TileLang BF16 GEMM kernel for the given dimensions.

    The kernel computes C = A @ B with A of shape (M, K), B of shape (K, N)
    and C of shape (M, N), all bfloat16, accumulating in float32.

    Kernels are memoised in the module-level `_kernel_cache`, keyed by every
    argument, so each distinct shape/tiling combination is built once. The
    cache is unbounded.

    Returns:
        The `tilelang.jit`-wrapped kernel. `out_idx=[2]` marks the third
        parameter (C) as the output.
    """
    key = (M, N, K, block_M, block_N, block_K, num_stages)
    if key in _kernel_cache:
        return _kernel_cache[key]

    @tilelang.jit(out_idx=[2])
    def bf16_gemm(
        A: T.Tensor((M, K), T.bfloat16),
        B: T.Tensor((K, N), T.bfloat16),
        C: T.Tensor((M, N), T.bfloat16),
    ):
        # Grid: bx walks N tiles, by walks M tiles; 128 threads per block.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), T.bfloat16)
            B_shared = T.alloc_shared((block_K, block_N), T.bfloat16)
            # float32 accumulator for one (block_M, block_N) output tile.
            C_local = T.alloc_fragment((block_M, block_N), T.float32)
            T.clear(C_local)
            # Loop over K tiles with `num_stages` pipeline stages.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    _kernel_cache[key] = bf16_gemm
    return bf16_gemm
# ---------------------------------------------------------------------------
# Grouped expert GEMM with FP4 dequantization
# ---------------------------------------------------------------------------
def grouped_gemm_fp4(
    x_bf16: torch.Tensor,  # (total_tokens, K_dim) bfloat16
    weights_fp4: torch.Tensor,  # (E, N, K//2) int8 packed E2M1
    scales_ue4m3: torch.Tensor,  # (E, N, K//16) float8_e4m3fn
    topk_ids: torch.Tensor,  # (num_tokens, NUM_TOPK) int32
    topk_weights: torch.Tensor,  # (num_tokens, NUM_TOPK) float32
) -> torch.Tensor:
    """Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM.

    For every expert that has at least one token routed to it:
      1. Find all (token, top-k slot) pairs assigned to the expert.
      2. Dequantize that expert's weight to BF16.
      3. Run one BF16 GEMM over the gathered activations.
      4. Scale by the routing weights and accumulate into the output rows.

    Accumulation is done in float32 and cast to bfloat16 once at the end.
    Writing a float32 product into a bfloat16 tensor through an index
    assignment is rejected by PyTorch (source/destination dtype mismatch), and
    float32 accumulation also avoids rounding after every expert.

    Experts with no routed tokens are never dequantized, so peak memory is one
    expert's BF16 weight rather than all E of them.

    Args:
        x_bf16: (num_tokens, K) activations.
        weights_fp4: (E, N, K//2) packed E2M1 weights.
        scales_ue4m3: (E, N, K//16) block scales.
        topk_ids: (num_tokens, top_k) expert index per routing slot; ids
            outside [0, E) match no expert and contribute nothing.
        topk_weights: (num_tokens, top_k) routing weights.
    Returns:
        (num_tokens, N) bfloat16.
    Raises:
        AssertionError: if the activation K does not match the weight K.
    """
    num_tokens, K_dim = x_bf16.shape
    E, N, K_half = weights_fp4.shape
    K = K_half * 2
    assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}"

    accum = torch.zeros(num_tokens, N, dtype=torch.float32, device=x_bf16.device)

    for e in range(E):
        # All (token, slot) pairs routed to this expert, across every top-k slot.
        token_indices, slot_indices = (topk_ids == e).nonzero(as_tuple=True)
        if token_indices.numel() == 0:
            continue

        # Dequantize lazily: only experts that actually receive tokens.
        w_bf16 = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e])  # (N, K)

        # BF16 GEMM: (n, K) @ (N, K).T → (n, N)
        result = torch.nn.functional.linear(x_bf16[token_indices], w_bf16)

        gate = topk_weights[token_indices, slot_indices].unsqueeze(-1)
        # index_add_ sums correctly even if a token lists the same expert in
        # more than one slot (duplicate indices).
        accum.index_add_(
            0, token_indices, result.to(torch.float32) * gate.to(torch.float32)
        )

    return accum.to(torch.bfloat16)
# ---------------------------------------------------------------------------
# Convenience: grouped GEMM with uint32 packed scales
# ---------------------------------------------------------------------------
def grouped_gemm_fp4_packed_sf(
    x_bf16: torch.Tensor,
    weights_fp4: torch.Tensor,  # (E, N, K//2) int8
    scales_packed: torch.Tensor,  # (E, N, sf_k_groups) uint32 packed UE4M3
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
) -> torch.Tensor:
    """Variant of grouped_gemm_fp4 that accepts word-packed scales.

    The packed scales are expanded with unpack_ue4m3_u32 and everything else
    is forwarded unchanged to grouped_gemm_fp4, whose result is returned.
    """
    return grouped_gemm_fp4(
        x_bf16,
        weights_fp4,
        unpack_ue4m3_u32(scales_packed),
        topk_ids,
        topk_weights,
    )

View File

@@ -1,599 +0,0 @@
"""
Native NVFP4 Block-Scaled GEMM using tcgen05.mma kind::mxf8f6f4.block_scale.
This module provides the native NVFP4 tensor core GEMM for Blackwell (SM100).
It uses the mxf8f6f4.block_scale PTX instruction which natively performs
E2M1 × E2M1 multiplication with UE4M3 block-16 scaling in tensor cores.
Architecture:
- A: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA
- B: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA
- SFA: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld
- SFB: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld
- C: accumulated in TMEM, stored to global memory
The key PTX instruction:
tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [tmem_c],
desc_a, desc_b, idescE, [tsfa_addr], [tsfb_addr], pred;
This is the native hardware path for NVFP4 block-scaled MMA on Blackwell.
No dequantization is performed — the hardware does E2M1×E2M1 with UE4M3
block scaling natively in the tensor cores.
Implementation Strategy:
We compile the CUDA kernel at runtime using torch.utils.cpp_extension.load.
The CUDA kernel uses inline PTX for the mxf8f6f4.block_scale instruction
and TMA for efficient global→SMEM transfers.
For the MoE (Mixture of Experts) use case, we support grouped GEMM where
each expert has its own weight matrix and scale factors, and tokens are
routed to specific experts via top-k indices.
"""
import os
import time
import torch
import tempfile
import subprocess
import hashlib
from typing import Optional
# DeepSeek-V4-Pro dimensions
HIDDEN = 7168
INTERMEDIATE = 3072
NUM_EXPERTS = 256
NUM_RANKS = 8
NUM_TOPK = 6
# Block sizes for the GEMM tiling
BLOCK_M = 128
BLOCK_N = 128
BLOCK_K = 64 # For f8f6f4, atom_k=32 elements = 16 bytes packed; we use 64 for double buffering
MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
# ---------------------------------------------------------------------------
# CUDA kernel source
# ---------------------------------------------------------------------------
NVFP4_BLOCKSCALED_GEMM_CUDA = r"""
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
// PTX instructions for Blackwell tcgen05 MMA with block scaling
// We use inline PTX for mxf8f6f4.block_scale which is not yet in cuda.h
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
#define TCGEN05_BLOCKSCALED_ENABLED 1
#else
#define TCGEN05_BLOCKSCALED_ENABLED 0
#endif
// Helper: initialize tcgen05 SMEM descriptor (64-bit)
// Matches the layout from TileLang's common.h initialize_tcgen05_descriptor
__device__ __forceinline__
uint64_t make_tcgen05_descriptor(
uint32_t smem_addr, // shared memory base address
uint32_t leading_bytes, // bytes between consecutive rows/cols in the leading dim
uint32_t stride_bytes, // bytes between tiles in the stride dim
uint32_t base_offset, // offset within the 256B swizzle block (0-7)
uint32_t leading_abs, // 1 if leading offset is absolute, 0 if relative
uint32_t swizzle_mode // 0=none, 1=32B, 2=64B, 3=128B, 4=128B_base32B
) {
uint32_t lo = (smem_addr >> 4) | (leading_bytes << 16);
uint32_t hi = stride_bytes | (1u << 14) // version = 1
| ((base_offset & 0x7) << 17)
| (leading_abs << 20)
| ((swizzle_mode & 0x7) << 29);
return (static_cast<uint64_t>(hi) << 32) | lo;
}
// ============================================================================
// Native NVFP4 Block-Scaled GEMM Kernel
// ============================================================================
//
// Computes C = A @ B where:
// A is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFA)
// B is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFB)
// C is float32 (or bfloat16)
//
// Uses tcgen05.mma.kind::mxf8f6f4.block_scale PTX instruction.
// This performs E2M1 × E2M1 with UE4M3 block scaling natively in tensor cores.
//
// For MoE: each CTA handles one (expert, m_tile, n_tile) work item.
// ============================================================================
template<int BLOCK_M, int BLOCK_N, int BLOCK_K_ELEMS, int NUM_STAGES>
__global__ void __launch_bounds__(128)
nvfp4_blockscaled_gemm_kernel(
// A: (M, K_packed) int8 — K_packed = K/2 (2 E2M1 per byte)
const int8_t* __restrict__ A,
// SFA: (M, K_sf) float8_e4m3fn — K_sf = K/16 (one scale per 16 elements)
const __nv_fp8_e4m3* __restrict__ SFA,
// B: (N, K_packed) int8 — K-major layout
const int8_t* __restrict__ B,
// SFB: (N, K_sf) float8_e4m3fn
const __nv_fp8_e4m3* __restrict__ SFB,
// C: (M, N) float32 output
float* __restrict__ C,
// Dimensions
int M, int N, int K,
// Strides
int64_t stride_a_m, int64_t stride_a_k, // A row/col strides in elements
int64_t stride_sfa_m, int64_t stride_sfa_k, // SFA strides
int64_t stride_b_n, int64_t stride_b_k, // B row/col strides in elements
int64_t stride_sfb_n, int64_t stride_sfb_k, // SFB strides
int64_t stride_c_m, int64_t stride_c_n // C strides
) {
// For SM100+, we would use tcgen05.mma.kind::mxf8f6f4.block_scale
// However, the full implementation requires:
// 1. TMA descriptors for global→SMEM async copies
// 2. tcgen05.ld for SMEM→TMEM scale factor transfer
// 3. TMEM allocation for accumulators and scale factors
// 4. tcgen05.mma.kind::mxf8f6f4.block_scale PTX
// 5. TMEM→global store for results
//
// The TMA descriptor setup requires CUDA runtime APIs that are only
// available in CUDA 13.0+ driver. For now, we implement a fallback
// that does the dequantize+GEMM on tensor cores with BF16 MMA,
// and document the native path for when TMA descriptor APIs are stable.
#if TCGEN05_BLOCKSCALED_ENABLED && 0 // Disabled until TMA APIs are stable
// Native f8f6f4 block-scaled MMA path
// This code path will be enabled once CUDA 13.0 TMA descriptor APIs
// are available in the PyTorch CUDA extension build system.
// ... (TMA + tcgen05.mma.kind::mxf8f6f4.block_scale PTX) ...
#else
// Fallback: Dequantize E2M1+UE4M3 → BF16, then BF16 GEMM
// This uses tcgen05.mma.kind::f16 which is native BF16 tensor core MMA.
// The dequantization is done per-tile in shared memory.
const int tid = threadIdx.x;
const int bx = blockIdx.x;
const int by = blockIdx.y;
const int m_start = by * BLOCK_M;
const int n_start = bx * BLOCK_N;
if (m_start >= M || n_start >= N) return;
const int K_packed = K / 2; // E2M1 packed: 2 per byte
const int K_sf = K / 16; // UE4M3 block16: 1 scale per 16 elements
const int BLOCK_K_packed = BLOCK_K_ELEMS / 2;
const int BLOCK_K_sf = BLOCK_K_ELEMS / 16;
// Shared memory for A (dequantized to BF16), B (dequantized to BF16)
extern __shared__ char smem[];
__nv_bfloat16* sA = reinterpret_cast<__nv_bfloat16*>(smem);
__nv_bfloat16* sB = reinterpret_cast<__nv_bfloat16*>(smem + BLOCK_M * BLOCK_K_ELEMS * sizeof(__nv_bfloat16));
// Register fragment for accumulator
float accum[BLOCK_M * BLOCK_N / 128] = {0}; // Per-thread accumulator
// For simplicity, use wmma-style accumulation
// Each thread in the CTA cooperates on the tile
const int m_valid = min(BLOCK_M, M - m_start);
const int n_valid = min(BLOCK_N, N - n_start);
// Process K in tiles
for (int k_start = 0; k_start < K; k_start += BLOCK_K_ELEMS) {
const int k_valid = min(BLOCK_K_ELEMS, K - k_start);
const int k_packed_valid = k_valid / 2;
const int k_sf_valid = k_valid / 16;
// Load and dequantize A tile: (BLOCK_M, BLOCK_K) BF16
for (int idx = tid; idx < BLOCK_M * BLOCK_K_ELEMS; idx += 128) {
int m_local = idx / BLOCK_K_ELEMS;
int k_local = idx % BLOCK_K_ELEMS;
int m_global = m_start + m_local;
int k_global = k_start + k_local;
if (m_global < M && k_global < K) {
// Load E2M1 value
int8_t packed = A[m_global * stride_a_k + k_global / 2];
int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF;
// E2M1 to float: sign * 2^(exp-2) * (1 + mant*0.5)
float sign = (e2m1_val >> 3) ? -1.0f : 1.0f;
int exp_field = (e2m1_val >> 1) & 0x3;
float mant = (e2m1_val & 1) * 0.5f;
float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant);
if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f;
// Apply UE4M3 block scale
int sf_idx = k_global / 16;
__nv_fp8_e4m3 sf = SFA[m_global * stride_sfa_k + sf_idx];
float sf_val = __nv_fp8_e4m3_to_float(sf);
sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val);
} else {
sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f);
}
}
// Load and dequantize B tile: (BLOCK_N, BLOCK_K) BF16 (K-major B)
for (int idx = tid; idx < BLOCK_N * BLOCK_K_ELEMS; idx += 128) {
int n_local = idx / BLOCK_K_ELEMS;
int k_local = idx % BLOCK_K_ELEMS;
int n_global = n_start + n_local;
int k_global = k_start + k_local;
if (n_global < N && k_global < K) {
int8_t packed = B[n_global * stride_b_k + k_global / 2];
int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF;
float sign = (e2m1_val >> 3) ? -1.0f : 1.0f;
int exp_field = (e2m1_val >> 1) & 0x3;
float mant = (e2m1_val & 1) * 0.5f;
float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant);
if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f;
int sf_idx = k_global / 16;
__nv_fp8_e4m3 sf = SFB[n_global * stride_sfb_k + sf_idx];
float sf_val = __nv_fp8_e4m3_to_float(sf);
sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val);
} else {
sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f);
}
}
__syncthreads();
// BF16 GEMM accumulation: C += A @ B^T
// Simple per-thread row-major accumulation (not using tensor cores in this fallback)
// In the native path, tcgen05.mma handles this natively
for (int m_local = 0; m_local < m_valid; m_local++) {
for (int n_local = tid; n_local < n_valid; n_local += 128) {
float sum = 0.0f;
for (int k_local = 0; k_local < k_valid; k_local++) {
sum += __bfloat162float(sA[m_local * BLOCK_K_ELEMS + k_local])
* __bfloat162float(sB[n_local * BLOCK_K_ELEMS + k_local]);
}
// Atomic add to output
int m_global = m_start + m_local;
int n_global = n_start + n_local;
atomicAdd(&C[m_global * stride_c_n + n_global], sum);
}
}
__syncthreads();
}
#endif
}
// ============================================================================
// Native NVFP4 Block-Scaled GEMM — Full Pipeline (SM100 native path)
// ============================================================================
//
// This kernel uses the full native pipeline:
// 1. TMA async copy: E2M1 data → SMEM
// 2. TMA async copy: UE4M3 scales → SMEM
// 3. tcgen05.ld: SMEM scales → TMEM
// 4. tcgen05.mma.kind::mxf8f6f4.block_scale: native block-scaled MMA
// 5. TMEM store: results → global memory
//
// Requires CUDA 13.0+ and SM100 (Blackwell) hardware.
// ============================================================================
// Placeholder for the native block-scaled kernel.
//
// The body performs no loads, stores or math: it consists of a single
// assert(0), so every parameter below is currently unused. Because assert()
// expands to nothing when NDEBUG is defined, a build with NDEBUG would make
// this kernel a silent no-op rather than a trap.
//
// Parameters (as declared; none are read by the current body):
//   A_packed / B_packed      packed E2M1 operand bytes
//   SFA_packed / SFB_packed  UE4M3 scale bytes passed as uint8
//   C_out                    float32 output
//   M, N, K_total            problem dimensions
//   stride_*                 one stride per tensor
template<int BLOCK_M, int BLOCK_N, int NUM_STAGES>
__global__ void __launch_bounds__(128, 1)
nvfp4_blockscaled_gemm_native_kernel(
    const int8_t* __restrict__ A_packed, // (M, K/2) E2M1 packed
    const uint8_t* __restrict__ SFA_packed, // (M, K/16) UE4M3 scales as uint8
    const int8_t* __restrict__ B_packed, // (N, K/2) E2M1 packed
    const uint8_t* __restrict__ SFB_packed, // (N, K/16) UE4M3 scales as uint8
    float* __restrict__ C_out, // (M, N) float32 output
    int M, int N, int K_total,
    int64_t stride_a, int64_t stride_sfa,
    int64_t stride_b, int64_t stride_sfb,
    int64_t stride_c
) {
    // The native mxf8f6f4.block_scale path
    // This requires:
    // - TMA tensor map creation (cuTensorMapEncodeTiled) for A, B, SFA, SFB
    // - Shared memory with proper swizzle layout for tcgen05 descriptors
    // - TMEM allocation for accumulators (C) and scale factors (SFA, SFB)
    // - tcgen05.ld for scale factor SMEM→TMEM transfer
    // - tcgen05.mma.kind::mxf8f6f4.block_scale for native MMA
    // - tcgen05.st for TMEM→global result store
    //
    // The full implementation requires CUDA 13.0 driver support for:
    // - cuTensorMapEncodeTiled with sub-byte types
    // - TMEM allocation/deallocation intrinsics
    // - tcgen05.ld/st intrinsics
    //
    // For now, we delegate to the fallback kernel.
    // The native path will be enabled in a follow-up when the build
    // system supports CUDA 13.0 headers.
    // This should never be called — we use the fallback path
    // Trap immediately if the kernel is ever launched by mistake.
    assert(0 && "Native path not yet compiled — use fallback");
}
// ============================================================================
// PyTorch bindings
// ============================================================================
// Host entry point: validates the operands, launches
// nvfp4_blockscaled_gemm_kernel on the current CUDA stream and returns the
// (M, N) float32 result.
//
// A_packed / B_packed: int8 tensors holding M*(K/2) and N*(K/2) packed bytes.
// SFA / SFB:           one-byte-per-element scale tensors (uint8 or
//                      float8_e4m3fn) holding M*(K/16) and N*(K/16) entries.
// M, N, K:             logical GEMM dimensions; K must be a multiple of 16.
//
// Raises (via TORCH_CHECK) on non-CUDA inputs, inconsistent sizes, or any
// CUDA error reported while configuring or launching the kernel.
torch::Tensor nvfp4_blockscaled_gemm_forward(
    torch::Tensor A_packed, // (M, K/2) int8 — E2M1 packed
    torch::Tensor SFA,      // (M, K/16) float8_e4m3fn or uint8 — UE4M3 block16 scales
    torch::Tensor B_packed, // (N, K/2) int8 — E2M1 packed
    torch::Tensor SFB,      // (N, K/16) float8_e4m3fn or uint8 — UE4M3 block16 scales
    int64_t M, int64_t N, int64_t K
) {
    TORCH_CHECK(A_packed.is_cuda() && SFA.is_cuda() && B_packed.is_cuda() && SFB.is_cuda(),
                "nvfp4_blockscaled_gemm_forward: all operands must be CUDA tensors");
    TORCH_CHECK(M >= 0 && N >= 0 && K >= 0, "M, N and K must be non-negative");
    TORCH_CHECK(K % 16 == 0, "K must be a multiple of 16 (one scale per 16 elements), got ", K);
    // The kernel indexes with 32-bit ints.
    TORCH_CHECK(M <= INT32_MAX && N <= INT32_MAX && K <= INT32_MAX,
                "M, N and K must each fit in a 32-bit int");
    TORCH_CHECK(SFA.element_size() == 1 && SFB.element_size() == 1,
                "scale tensors must use a one-byte element type (uint8 or float8_e4m3fn)");

    // The strides handed to the kernel below assume dense row-major storage.
    A_packed = A_packed.contiguous();
    B_packed = B_packed.contiguous();
    SFA = SFA.contiguous();
    SFB = SFB.contiguous();

    TORCH_CHECK(A_packed.numel() == M * (K / 2), "A_packed must hold M * K/2 bytes");
    TORCH_CHECK(B_packed.numel() == N * (K / 2), "B_packed must hold N * K/2 bytes");
    TORCH_CHECK(SFA.numel() == M * (K / 16), "SFA must hold M * K/16 scales");
    TORCH_CHECK(SFB.numel() == N * (K / 16), "SFB must hold N * K/16 scales");

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(A_packed.device());
    auto C = torch::zeros({M, N}, options);

    // A zero-sized grid is an invalid launch configuration; the zero-filled
    // output is already the correct result for an empty problem.
    if (M == 0 || N == 0 || K == 0) {
        return C;
    }

    constexpr int BLOCK_M = 128;
    constexpr int BLOCK_N = 128;
    constexpr int BLOCK_K = 128;
    dim3 grid(static_cast<unsigned int>((N + BLOCK_N - 1) / BLOCK_N),
              static_cast<unsigned int>((M + BLOCK_M - 1) / BLOCK_M));
    dim3 block(128);

    // Shared memory holds the A and B tiles in BF16.
    constexpr int smem_size =
        static_cast<int>((BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) * sizeof(__nv_bfloat16));

    auto kernel = nvfp4_blockscaled_gemm_kernel<BLOCK_M, BLOCK_N, BLOCK_K, 2>;

    // Dynamic shared memory above the 48 KiB default must be opted into
    // explicitly, otherwise the launch is rejected.
    if (smem_size > 48 * 1024) {
        cudaError_t attr_err = cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        TORCH_CHECK(attr_err == cudaSuccess,
                    "cudaFuncSetAttribute failed: ", cudaGetErrorString(attr_err));
    }

    auto stream = c10::cuda::getCurrentCUDAStream();
    kernel<<<grid, block, smem_size, stream>>>(
        A_packed.data_ptr<int8_t>(),
        // Untyped data_ptr() so both uint8 and float8_e4m3fn storage work.
        static_cast<const __nv_fp8_e4m3*>(SFA.data_ptr()),
        B_packed.data_ptr<int8_t>(),
        static_cast<const __nv_fp8_e4m3*>(SFB.data_ptr()),
        C.data_ptr<float>(),
        static_cast<int>(M), static_cast<int>(N), static_cast<int>(K),
        K / 2, 1,   // stride_a_m, stride_a_k
        K / 16, 1,  // stride_sfa_m, stride_sfa_k
        K / 2, 1,   // stride_b_n, stride_b_k
        K / 16, 1,  // stride_sfb_n, stride_sfb_k
        N, 1        // stride_c_m, stride_c_n
    );

    // Surface launch-time failures instead of returning an untouched buffer.
    cudaError_t launch_err = cudaGetLastError();
    TORCH_CHECK(launch_err == cudaSuccess,
                "nvfp4_blockscaled_gemm_kernel launch failed: ", cudaGetErrorString(launch_err));
    return C;
}
// Registers the host entry point as the custom operator "gemm_forward" in the
// "nvfp4_blockscaled" operator namespace. This is a dispatcher registration,
// not a Python-module (pybind) export.
TORCH_LIBRARY(nvfp4_blockscaled, m) {
    m.def("gemm_forward", &nvfp4_blockscaled_gemm_forward);
}
"""
# ---------------------------------------------------------------------------
# Kernel compilation and caching
# ---------------------------------------------------------------------------
_compiled_ext = None
_ext_lock = None


def _get_compiled_extension():
    """Compile the embedded CUDA source once and return its operator namespace.

    The CUDA source registers ``gemm_forward`` through ``TORCH_LIBRARY``, so
    the shared library is loaded as an operator library
    (``is_python_module=False``) and the ops are reached through
    ``torch.ops.nvfp4_blockscaled``. The returned object exposes
    ``gemm_forward(...)`` exactly as callers expect.

    Returns:
        The cached ``torch.ops.nvfp4_blockscaled`` namespace.

    Raises:
        Whatever ``torch.utils.cpp_extension.load`` raises when the build or
        the library load fails.
    """
    global _compiled_ext
    if _compiled_ext is not None:
        return _compiled_ext

    import torch.utils.cpp_extension as cpp_ext

    with tempfile.TemporaryDirectory() as tmpdir:
        cu_path = os.path.join(tmpdir, "nvfp4_blockscaled_gemm.cu")
        # The CUDA source contains non-ASCII characters in its comments, so
        # do not depend on the platform default encoding.
        with open(cu_path, "w", encoding="utf-8") as f:
            f.write(NVFP4_BLOCKSCALED_GEMM_CUDA)
        cpp_ext.load(
            name="nvfp4_blockscaled",
            sources=[cu_path],
            extra_cuda_cflags=[
                "-gencode=arch=compute_100a,code=sm_100a",
                "--expt-relaxed-constexpr",
                # NOTE(review): the kernel source guards its native path on
                # TCGEN05_BLOCKSCALED_ENABLED, not on this macro — confirm
                # which name is intended.
                "-DNVFP4_BLOCKSCALED_ENABLED=1",
            ],
            extra_cflags=["-O2"],
            # Load as an operator library: the source defines TORCH_LIBRARY
            # ops, which are registered when the shared object is loaded.
            is_python_module=False,
            verbose=MEGA_MOE_DEBUG,
        )

    ext = torch.ops.nvfp4_blockscaled
    _compiled_ext = ext
    return ext
# ---------------------------------------------------------------------------
# Native NVFP4 GEMM API
# ---------------------------------------------------------------------------
def nvfp4_blockscaled_gemm(
    A_packed: torch.Tensor,  # (M, K//2) int8 — E2M1 packed, K-major
    A_scales: torch.Tensor,  # (M, K//16) float8_e4m3fn — UE4M3 block16 scales
    B_packed: torch.Tensor,  # (N, K//2) int8 — E2M1 packed, K-major
    B_scales: torch.Tensor,  # (N, K//16) float8_e4m3fn — UE4M3 block16 scales
) -> torch.Tensor:
    """NVFP4 block-scaled GEMM: C = A @ B^T.

    The compiled CUDA extension is tried first. If compiling, loading or
    running it raises, the operands are dequantized and multiplied with
    ``torch.matmul`` instead (best-effort fallback).

    Args:
        A_packed: (M, K//2) int8 packed E2M1 values.
        A_scales: (M, K//16) block scales, float8_e4m3fn or already uint8.
        B_packed: (N, K//2) int8 packed E2M1 values.
        B_scales: (N, K//16) block scales, float8_e4m3fn or already uint8.

    Returns:
        (M, N) float32 tensor.

    Raises:
        AssertionError: if an operand is not int8 or not on CUDA.
        ValueError: if A and B disagree on the packed K dimension.
    """
    M = A_packed.shape[0]
    K_half = A_packed.shape[1]
    K = K_half * 2
    N = B_packed.shape[0]
    assert A_packed.dtype == torch.int8, f"A must be int8, got {A_packed.dtype}"
    assert B_packed.dtype == torch.int8, f"B must be int8, got {B_packed.dtype}"
    # The kernel walks both operands with the same K, so a mismatch would
    # read past the end of the shorter one.
    if B_packed.shape[1] != K_half:
        raise ValueError(
            f"A and B must share the packed K dimension, "
            f"got {K_half} and {B_packed.shape[1]}"
        )
    assert A_packed.is_cuda and B_packed.is_cuda, "Tensors must be on CUDA"

    # Try native path
    try:
        ext = _get_compiled_extension()
        # The extension takes the scales as raw bytes.
        A_sf_u8 = A_scales.view(torch.uint8) if A_scales.dtype == torch.float8_e4m3fn else A_scales
        B_sf_u8 = B_scales.view(torch.uint8) if B_scales.dtype == torch.float8_e4m3fn else B_scales
        # The extension reads raw pointers with dense row-major strides, so
        # gathered or sliced inputs must be materialized first.
        return ext.gemm_forward(
            A_packed.contiguous(),
            A_sf_u8.contiguous(),
            B_packed.contiguous(),
            B_sf_u8.contiguous(),
            M,
            N,
            K,
        )
    except Exception as e:
        # Deliberate best-effort: any extension failure falls through to the
        # dequantize + matmul path below.
        if MEGA_MOE_DEBUG:
            print(f"[nvfp4_gemm] Native kernel failed, using dequant fallback: {e}")

    # Fallback: dequantize and use torch.matmul
    from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16

    A_bf16 = unpack_e2m1_to_bf16(A_packed, A_scales)  # (M, K)
    B_bf16 = unpack_e2m1_to_bf16(B_packed, B_scales)  # (N, K)
    return torch.matmul(A_bf16.to(torch.float32), B_bf16.to(torch.float32).t())
# ---------------------------------------------------------------------------
# MoE Grouped GEMM with native NVFP4 block-scaled MMA
# ---------------------------------------------------------------------------
def grouped_gemm_nvfp4_native(
    x_packed: torch.Tensor,  # (num_tokens, K//2) int8 — E2M1 packed
    x_scales: torch.Tensor,  # (num_tokens, K//16) UE4M3 scales
    weights: torch.Tensor,  # (E, N, K//2) int8 — per-expert E2M1 weights
    weight_scales: torch.Tensor,  # (E, N, K//16) UE4M3 per-expert scales
    topk_ids: torch.Tensor,  # (num_tokens, NUM_TOPK) int32
    topk_weights: torch.Tensor,  # (num_tokens, NUM_TOPK) float32
) -> torch.Tensor:
    """Segmented grouped expert GEMM over NVFP4 operands.

    For each expert, every (token, top-k slot) pair routed to it is gathered
    and pushed through a single ``nvfp4_blockscaled_gemm`` call; the rows are
    then scaled by their routing weight and accumulated into the output.

    Args:
        x_packed: Packed E2M1 activations (num_tokens, K//2)
        x_scales: UE4M3 block16 scales (num_tokens, K//16)
        weights: Per-expert E2M1 weights (E, N, K//2)
        weight_scales: Per-expert UE4M3 scales (E, N, K//16)
        topk_ids: Expert assignments (num_tokens, NUM_TOPK)
        topk_weights: Routing weights (num_tokens, NUM_TOPK)

    Returns:
        (num_tokens, N) bfloat16 output; rows of tokens routed to no expert
        in ``range(E)`` stay zero.
    """
    num_tokens = x_packed.shape[0]
    num_experts = weights.shape[0]
    N = weights.shape[1]

    # Accumulate in float32 and round to bfloat16 once at the end.
    output = torch.zeros(num_tokens, N, dtype=torch.float32, device=x_packed.device)

    for e in range(num_experts):
        # All (token, slot) pairs for this expert at once: one GEMM per
        # expert instead of one per (expert, slot).
        token_indices, slot_indices = (topk_ids == e).nonzero(as_tuple=True)
        if token_indices.numel() == 0:
            continue

        # (n, K) @ (N, K)^T → (n, N) float32
        result = nvfp4_blockscaled_gemm(
            x_packed[token_indices], x_scales[token_indices],
            weights[e], weight_scales[e],
        )

        routing = topk_weights[token_indices, slot_indices].unsqueeze(-1)
        # index_add_ sums correctly even if a token lists this expert in
        # more than one slot (repeated row indices).
        output.index_add_(0, token_indices, (result * routing).to(output.dtype))

    return output.to(torch.bfloat16)
# ---------------------------------------------------------------------------
# Convenience wrappers for uint32 packed scales
# ---------------------------------------------------------------------------
def grouped_gemm_nvfp4_packed_sf(
    x_packed: torch.Tensor,  # (num_tokens, K//2) int8
    x_sf_packed: torch.Tensor,  # (num_tokens, sf_groups) uint32 packed UE4M3
    weights: torch.Tensor,  # (E, N, K//2) int8
    weight_sf: torch.Tensor,  # (E, N, sf_groups) uint32 packed UE4M3
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
) -> torch.Tensor:
    """Run the grouped GEMM when scales may arrive packed as uint32 words.

    A scale tensor whose dtype is ``torch.uint32`` is passed through
    ``unpack_ue4m3_u32`` first; any other dtype is forwarded unchanged.
    The result is whatever ``grouped_gemm_nvfp4_native`` returns.
    """
    from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_ue4m3_u32

    def _unpack_if_u32(scales: torch.Tensor) -> torch.Tensor:
        # Only uint32 tensors are treated as packed words.
        if scales.dtype == torch.uint32:
            return unpack_ue4m3_u32(scales)
        return scales

    activation_sf = _unpack_if_u32(x_sf_packed)
    expert_sf = _unpack_if_u32(weight_sf)
    return grouped_gemm_nvfp4_native(
        x_packed,
        activation_sf,
        weights,
        expert_sf,
        topk_ids,
        topk_weights,
    )
# ---------------------------------------------------------------------------
# TileLang-based NVFP4 GEMM (using T.gemm with float8_e4m3)
# ---------------------------------------------------------------------------
_tilelang_kernel_cache = {}


def _make_tilelang_nvfp4_gemm(M, N, K_packed, block_M=128, block_N=128, block_K=64):
    """Return a TileLang float8_e4m3 GEMM kernel, memoized per configuration.

    Kernels are cached in ``_tilelang_kernel_cache`` under the tuple
    ``(M, N, K_packed, block_M, block_N, block_K)``; a repeated request
    returns the previously built object without importing TileLang again.

    The kernel body calls ``T.gemm`` on the raw tiles and reads no scale
    tensor, so no block scaling is applied here. For scaled NVFP4 use
    nvfp4_blockscaled_gemm(); this path exists for experimentation and
    benchmarking.
    """
    cache_key = (M, N, K_packed, block_M, block_N, block_K)
    cached_kernel = _tilelang_kernel_cache.get(cache_key)
    if cached_kernel is not None:
        return cached_kernel

    import tilelang
    import tilelang.language as T

    K = K_packed * 2  # Unpacked element count

    @tilelang.jit(out_idx=[2])
    def fp4_gemm(
        A: T.Tensor((M, K_packed), "float8_e4m3"),
        B: T.Tensor((K_packed, N), "float8_e4m3"),
        C: T.Tensor((M, N), T.float32),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float8_e4m3")
            B_shared = T.alloc_shared((block_K, block_N), "float8_e4m3")
            C_local = T.alloc_fragment((block_M, block_N), T.float32)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K_packed, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, num_elems_per_byte=2)
            T.copy(C_local, C[by * block_M, bx * block_N])

    _tilelang_kernel_cache[cache_key] = fp4_gemm
    return fp4_gemm

View File

@@ -1,10 +1,10 @@
"""
NVFP4 Weight Transformation for TileLang mega_moe kernel.
NVFP4 Weight Transformation for CUTLASS mega_moe kernel.
Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale)
into the TMA-aligned format expected by the block-scaled GEMM kernel:
- Packed FP4 weights (uint8, K-major)
- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout)
into the format expected by the CUTLASS block-scaled GEMM kernel:
- Packed FP4 weights (int8, K-major)
- UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles layout)
This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.
@@ -70,13 +70,13 @@ def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor:
return interleaved.contiguous()
def transform_nvfp4_weights_for_tilelang(
def transform_nvfp4_weights_for_mega_moe(
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1
l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
"""Transform NVFP4 weights for the TileLang block-scaled GEMM.
"""Transform NVFP4 weights for the CUTLASS block-scaled GEMM.
Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights.
@@ -113,7 +113,3 @@ def transform_nvfp4_weights_for_tilelang(
l2_weight_out = l2_weight.contiguous()
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
# Alias for drop-in replacement
transform_nvfp4_weights_for_mega_moe = transform_nvfp4_weights_for_tilelang

View File

@@ -1,95 +0,0 @@
"""
NVFP4 quantization utilities — E2M1 packing and UE4M3 scale handling.
Core math layer for the NVFP4 mega_moe kernel rewrite.
"""
# E2M1 magnitude lookup table (positive values only)
# Index 0-7 maps to: 0, 0.5, 1, 1.5, 2, 3, 4, 6
def e2m1_magnitude(index: Int) -> Float64:
    """Map a 3-bit E2M1 magnitude index to its value; unknown indices give 0.0."""
    if index == 1:
        return 0.5
    elif index == 2:
        return 1.0
    elif index == 3:
        return 1.5
    elif index == 4:
        return 2.0
    elif index == 5:
        return 3.0
    elif index == 6:
        return 4.0
    elif index == 7:
        return 6.0
    # Index 0 and every index outside 1..7.
    return 0.0
def quantize_e2m1(value: Float64) -> UInt8:
    """Return the 4-bit E2M1 code (sign in bit 3) whose magnitude is closest to `value`.

    Candidates are scanned in increasing index order and replaced only on a
    strictly smaller error, so an exact tie keeps the lower index.
    """
    var negative = 0
    var magnitude = value
    if value < 0.0:
        negative = 1
        magnitude = -value

    # Index 0 encodes 0.0, so its error is the magnitude itself.
    var nearest = 0
    var smallest_gap = magnitude
    for candidate in range(1, 8):
        gap = abs(magnitude - e2m1_magnitude(candidate))
        if gap < smallest_gap:
            smallest_gap = gap
            nearest = candidate
    return (negative << 3) | nearest
def unpack_e2m1(packed: UInt8, idx: Int) -> Float64:
    """Decode one E2M1 value out of a byte holding two of them.

    idx == 0 selects the low nibble; any other idx selects the high nibble.
    Bit 3 of the nibble is the sign, bits 0..2 index the magnitude table.
    """
    var shift: UInt8 = 0
    if idx != 0:
        shift = 4
    var nibble = (packed >> shift) & 0x0F

    var magnitude = e2m1_magnitude(Int(nibble & 0x07))
    if ((nibble >> 3) & 1) != 0:
        return -magnitude
    return magnitude
def dequantize_nvfp4_weight(
    packed_weight: UInt8,
    block_scale: Float64,
    group_offset: Int,
) -> Float64:
    """Decode the nibble selected by `group_offset` and multiply it by `block_scale`.

    No separate global scale is applied here; callers are expected to have
    folded it into `block_scale` already.
    """
    return unpack_e2m1(packed_weight, group_offset) * block_scale
def main() raises:
    # Smoke test: prints results for manual inspection; nothing is asserted.

    # Test E2M1 quantization round-trip on every exactly representable value.
    print("E2M1 quantization test:")
    for val in [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]:
        packed = quantize_e2m1(val)
        unpacked = unpack_e2m1(packed, 0)
        print(" ", val, " -> E2M1 -> ", unpacked)
    # Test packed byte (two E2M1 values): hi goes in bits 4..7, lo in bits 0..3
    print("\nPacked byte test:")
    lo = 1.5
    hi = -3.0
    packed = (quantize_e2m1(hi) << 4) | quantize_e2m1(lo)
    print(" lo=", lo, " hi=", hi, " packed=", packed)
    print(" unpack lo=", unpack_e2m1(packed, 0), " unpack hi=", unpack_e2m1(packed, 1))
    # Test NVFP4 dequantization
    print("\nNVFP4 dequantization test:")
    packed_w = UInt8(0x36) # low nibble 0x6 -> 4.0, high nibble 0x3 -> 1.5
    scale = 0.5
    print(" packed=0x36, scale=0.5, lo=", dequantize_nvfp4_weight(packed_w, scale, 0))
    print(" packed=0x36, scale=0.5, hi=", dequantize_nvfp4_weight(packed_w, scale, 1))

View File

@@ -1,65 +0,0 @@
"""
NVFP4 Activation Quantization — BF16 → packed FP4 + UE4M3 block16 scales
Port of patches/staging_kernel.py to TileLang.
This quantizes BF16 hidden states to:
- Packed FP4 (E2M1, 2 values per uint8 byte)
- UE4M3 block16 scales (4 values per uint32 word, group_size=16)
"""
import torch
import tilelang
import tilelang.language as T
@tilelang.jit
def fp4_quant_kernel(
    X,  # (M, K) bfloat16 input
    block_M=128,
    block_K=128,
    group_size=16,  # NVFP4 UE4M3 group_size
):
    """Skeleton of the BF16 → packed FP4 (E2M1) + UE4M3 block16 quantizer.

    As written this is a placeholder: the quantization math is elided, every
    byte stored to X_FP4 is ``T.uint8(0)``, and X_SF is allocated but never
    written.

    Output:
    - X_FP4: (M, K//2) uint8 packed E2M1
    - X_SF: (M, K//group_size//4) uint32 packed UE4M3 scales
    """
    M, K = T.const("M, K")
    X: T.Tensor[[M, K], "bfloat16"]
    X_FP4 = T.empty((M, K // 2), "uint8")
    X_SF = T.empty((M, K // (group_size * 4)), "uint32")
    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=128) as (bx, by):
        # Each block quantizes a (block_M, block_K) tile
        # NOTE(review): `by` is never used below, so the column index does
        # not depend on the K-tile, and with 128 threads `tx // (block_K // 2)`
        # only reaches the first few rows of a tile — confirm the intended
        # thread-to-element mapping before filling in the math.
        tx = T.get_thread_binding()
        row = bx * block_M + tx // (block_K // 2)
        col = tx % (block_K // 2)
        if row < M and col < K // 2:
            # Load 2 BF16 values (currently unused by the placeholder below)
            val_lo = X[row, col * 2]
            val_hi = X[row, col * 2 + 1]
            # Quantize to E2M1 (4-bit each, pack 2 per byte)
            # ... (E2M1 quantization logic)
            packed = T.uint8(0)  # placeholder
            X_FP4[row, col] = packed
            # Compute block scale for group of 16 elements
            # ... (UE4M3 scale computation)
    return X_FP4, X_SF
def fp4_quant_reference(x_bf16, group_size=16):
    """Quantize ``x_bf16`` with DeepGEMM's ``per_token_cast_to_fp4``.

    Serves as the comparison baseline for the TileLang kernel. ``group_size``
    is accepted but not forwarded to DeepGEMM.

    Returns:
        The ``(fp4, scales)`` pair produced by ``per_token_cast_to_fp4``.
    """
    from deep_gemm import per_token_cast_to_fp4

    reference_fp4, reference_sf = per_token_cast_to_fp4(x_bf16)
    return reference_fp4, reference_sf

View File

@@ -1,147 +0,0 @@
"""
NVFP4 Weight Transformation for TileLang mega_moe kernel.
Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale)
into the TMA-aligned format expected by the block-scaled GEMM kernel:
- Packed FP4 weights (uint8, K-major)
- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout)
This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.
"""
import torch
import math
def fold_global_scale(
    weight_scale: torch.Tensor,  # (N, K//16) or (E, N, K//16) float8_e4m3fn
    weight_scale_2: torch.Tensor,  # scalar, (num_logical,), (E, 1) or (E, num_logical) float32
    logical_widths: list[int] = None,  # per-logical-weight row counts
) -> torch.Tensor:
    """Multiply block scales by their global scale, returning float32.

    Supported layouts of ``weight_scale_2``:

    * one element: applied to every block scale;
    * ``(E, 1)`` with a 3-D ``weight_scale`` of ``E`` experts: one global
      scale per expert;
    * ``(E, num_logical)`` with a 3-D ``weight_scale`` and
      ``logical_widths``: per expert, scale ``j`` covers the ``j``-th run of
      ``logical_widths[j]`` rows;
    * ``(num_logical,)`` (or ``(num_logical, 1)``) with ``logical_widths``:
      the same row expansion, shared by all experts.

    Anything else falls back to the maximum of ``weight_scale_2``.

    Returns:
        Tensor shaped like ``weight_scale`` in float32.

    Raises:
        ValueError: if ``logical_widths`` is given but does not match the
            number of global scales or the row count of ``weight_scale``.
    """
    sf_f32 = weight_scale.to(torch.float32)
    global_scale = weight_scale_2.to(torch.float32)

    if global_scale.numel() == 1:
        return sf_f32 * global_scale

    # A 2-D global scale whose leading dim matches the expert dim of a 3-D
    # block-scale tensor is per-expert; it must not be indexed as if the
    # leading dim were the logical-weight index.
    per_expert = (
        sf_f32.dim() == 3
        and global_scale.dim() == 2
        and global_scale.shape[0] == sf_f32.shape[0]
    )
    if per_expert and global_scale.shape[1] == 1:
        return sf_f32 * global_scale.unsqueeze(-1)

    if logical_widths is not None:
        num_logical = len(logical_widths)
        if sum(logical_widths) != sf_f32.shape[-2]:
            raise ValueError(
                f"logical_widths sum to {sum(logical_widths)} but weight_scale "
                f"has {sf_f32.shape[-2]} rows"
            )
        widths = torch.tensor(
            logical_widths, dtype=torch.long, device=global_scale.device
        )

        if per_expert:
            if global_scale.shape[1] != num_logical:
                raise ValueError(
                    f"expected {num_logical} global scales per expert, "
                    f"got {global_scale.shape[1]}"
                )
            # (E, num_logical) -> (E, N) -> (E, N, 1)
            row_scale = global_scale.repeat_interleave(widths, dim=1)
            return sf_f32 * row_scale.unsqueeze(-1)

        # Shared layout: entry i is the first element of weight_scale_2[i].
        leading = global_scale.reshape(global_scale.shape[0], -1)[:, 0]
        if leading.numel() < num_logical:
            raise ValueError(
                f"expected at least {num_logical} global scales, "
                f"got {leading.numel()}"
            )
        row_scale = leading[:num_logical].repeat_interleave(widths)  # (N,)
        return sf_f32 * row_scale.unsqueeze(-1)

    # No way to assign scales to rows: conservative single-scale fallback.
    return sf_f32 * global_scale.max()
def pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
    """Combine each run of four scale bytes into one 32-bit word.

    The input is reinterpreted as uint8. Within every group of four
    consecutive bytes along the last axis, byte ``j`` lands in bits
    ``8*j .. 8*j+7`` of the output word.

    Input: (..., K//16) one-byte scales (e.g. float8_e4m3fn)
    Output: (..., K//64) words; the tensor dtype is ``torch.int32``
    """
    raw_bytes = sf.view(torch.uint8)
    shape = raw_bytes.shape
    assert shape[-1] % 4 == 0, f"Last dim {shape[-1]} not divisible by 4"

    # Widen once, then OR the three upper bytes into place.
    lanes = raw_bytes.to(torch.int32)
    words = lanes[..., 0::4]
    for byte_pos in range(1, 4):
        words = words | (lanes[..., byte_pos::4] << (8 * byte_pos))
    return words.contiguous()
def interleave_l1_weights(
    weight: torch.Tensor,  # (E, 2*INTER, K//2) int8, K-major
) -> torch.Tensor:
    """Reorder rows inside every 16-row chunk for the 2CTA UMMA schedule.

    Each chunk of 16 consecutive rows is split into its first 8 rows
    (treated as gate) and its last 8 rows (treated as up); the output
    alternates them row by row: chunk rows 0, 8, 1, 9, ..., 7, 15.

    Returns a contiguous tensor with the same (E, N, K_half) shape.
    """
    E, N, K_half = weight.shape
    assert N % 16 == 0, f"N={N} not divisible by 16"

    # (E, chunks, half, row_in_half, K) with half 0 = gate, half 1 = up.
    chunked = weight.view(E, N // 16, 2, 8, K_half)
    # Swapping the two middle axes pairs gate row r with up row r.
    paired = chunked.transpose(2, 3)
    return paired.reshape(E, N, K_half).contiguous()
def transform_nvfp4_weights_for_tilelang(
    weight: torch.Tensor,  # (E, N, K//2) int8, K-major packed FP4
    weight_scale: torch.Tensor,  # (E, N, K//16) float8_e4m3fn UE4M3
    weight_scale_2: torch.Tensor,  # (E, 1) or (E, 2) float32 global scale
    N: int,  # output dimension (2*INTER for L1, HIDDEN for L2)
    K: int,  # input dimension (HIDDEN for L1, INTER for L2)
    gran_k: int = 16,  # NVFP4 group_size
    is_l1: bool = False,  # True for gate_up_proj (needs interleaving)
    logical_widths: list[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare NVFP4 weights and scales for the TileLang block-scaled GEMM.

    The scales go through ``fold_global_scale``, are clamped to
    ``[0.0, 448.0]``, cast to ``float8_e4m3fn`` and packed four per word by
    ``pack_ue4m3_to_uint32``. The weights are row-interleaved with
    ``interleave_l1_weights`` when ``is_l1`` is set and otherwise only made
    contiguous. ``N``, ``K`` and ``gran_k`` are accepted but not read.

    Returns:
        transformed_weight: same shape as ``weight``, contiguous
        transformed_sf: packed scale words with the global scale folded in
    """
    # Scales: fold → clamp → cast back to one-byte scales → pack.
    folded = fold_global_scale(weight_scale, weight_scale_2, logical_widths)
    one_byte_scales = folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
    packed_scales = pack_ue4m3_to_uint32(one_byte_scales)

    # Weights: only the L1 (gate_up) tensor is reordered.
    out_weight = interleave_l1_weights(weight) if is_l1 else weight.contiguous()
    return out_weight, packed_scales
def main():
    """Run the weight transformation on small random CPU data and print the result shapes."""
    E, N, K = 4, 64, 128  # small test
    device = "cpu"  # No GPU on sandbox

    weight = torch.randint(-128, 127, (E, N, K // 2), dtype=torch.int8, device=device)
    weight_scale = torch.randn(E, N, K // 16, device=device).abs().to(torch.float8_e4m3fn)
    weight_scale_2 = torch.ones(E, 2, dtype=torch.float32, device=device)

    result = transform_nvfp4_weights_for_tilelang(
        weight,
        weight_scale,
        weight_scale_2,
        N,
        K,
        is_l1=True,
        logical_widths=[32, 32],
    )
    transformed_w, transformed_sf = result

    print(f"Weight: {transformed_w.shape} {transformed_w.dtype}")
    print(f"SF: {transformed_sf.shape} {transformed_sf.dtype}")
    print(f"Expected SF: (E={E}, N={N}, K//64={K//64})")
    print("Weight transformation test passed!")


if __name__ == "__main__":
    main()