cleanup: remove abandoned TileLang and Mojo files
- Deleted: layout.mojo, mega_moe.mojo, quantize.mojo (Mojo attempt)
- Deleted: nvfp4_blockscaled_gemm.py, staging.py, nvfp4_mega_moe.py (TileLang top-level)
- Deleted: tilelang_nvfp4_gemm.py, tilelang_kernels.py, nvfp4_dequant.py (TileLang package)
- Deleted: src/weight_transform.py (duplicate of package version)
- Fixed nvfp4_mega_moe.py: inlined unpack_ue4m3_u32, removed TileLang fallback imports
- Fixed weight_transform.py: renamed function, removed TileLang alias, updated docs
- Fixed __init__.py: removed TileLang alias, updated docstring
- CUTLASS is the only kernel path now
This commit is contained in:
@@ -1,32 +0,0 @@
|
||||
"""
|
||||
NVFP4 weight transformation and SF layout utilities.
|
||||
|
||||
Port of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe
|
||||
"""
|
||||
|
||||
from math import ceil_div
|
||||
|
||||
fn fold_global_scale_into_block_scales(
    weight_scale: Tensor[float8_e4m3fn],  # (N, K//16) UE4M3 block scales
    weight_scale_2: Tensor[float32],  # (num_logical,) or scalar global scale
    logical_widths: List[int],  # per-logical-weight row counts
) -> Tensor[float32]:
    """Fold global scale into block scales: UE4M3 * FP32 -> FP32

    Unimplemented stub: the body is only `...`. The comments below record
    the intended steps; nothing is computed yet.
    """
    # Convert UE4M3 to float32, multiply by global scale
    # For MergedColumnParallelLinear, expand per-logical global scale
    # (presumably using `logical_widths` to repeat each scale per row -- TODO confirm)
    ...
|
||||
|
||||
fn pack_ue4m3_to_int32(sf: Tensor[float8_e4m3fn]) -> Tensor[int32]:
    """Pack 4 UE4M3 values (4 bytes) into one int32 for DeepGEMM TMA

    Unimplemented stub: the body is only `...`.
    """
    # View as uint8, pack 4 consecutive bytes into int32
    # NOTE(review): byte order within the int32 is not specified here -- confirm
    # against the consumer before implementing.
    ...
|
||||
|
||||
fn transform_sf_into_required_layout(
    sf_mn: Tensor[int32],  # MN-major packed SF
    N: int, K: int,
    recipe: Tuple[int, int],  # (gran_mn, gran_k)
    num_groups: int,
) -> Tensor[int32]:
    """Transform SF into TMA-aligned UTCCP layout for DeepGEMM

    Unimplemented stub: the body is only `...`. The intended implementation
    (per the comment below) delegates to DeepGEMM's C++ layout transform.
    """
    # Call into DeepGEMM's C++ layout transform
    ...
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
NVFP4 Mega MoE Kernel — Mojo Rewrite
|
||||
|
||||
This is a from-scratch rewrite of the DeepGEMM fp8_nvfp4_mega_moe kernel.
|
||||
The CUDA C++ version crashes on B200 with CUDA_ERROR_LAUNCH_FAILED in the
|
||||
SM100_MMA_MXF4NVF4 instruction. This Mojo rewrite aims to:
|
||||
1. Produce a correct, working NVFP4 mega_moe kernel
|
||||
2. Be more maintainable than 1200+ lines of CUDA template metaprogramming
|
||||
3. Leverage Mojo's GPU programming model for cleaner TMA/UMMA integration
|
||||
|
||||
NVFP4 format:
|
||||
- Weights: E2M1 packed (2 x 4-bit values per byte), int8 container
|
||||
- Block scales: UE4M3 (float8_e4m3fn), group_size=16
|
||||
- Global scale: float32 scalar
|
||||
- Activation: FP8 e4m3fn with UE8M0 per-token scales
|
||||
|
||||
Key differences from MXFP4 (group_size=32, UE8M0):
|
||||
- 2 SF K-columns per BLOCK_K (128/16/4=2) instead of 1
|
||||
- mxf4nvf4 UMMA instruction with scale_vec::4X
|
||||
- Different weight transformation (gran_k=16 vs gran_k=32)
|
||||
"""
|
||||
|
||||
# TODO: Implement as Mojo GPU kernel
|
||||
# Waiting for Mojo GPU programming docs / SDK setup
|
||||
@@ -1,256 +0,0 @@
|
||||
"""
|
||||
NVFP4 Block-Scaled GEMM — TileLang
|
||||
2CTA persistent kernel for Blackwell SM100
|
||||
|
||||
Adapted from tilelang/examples/blockscaled_gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py
|
||||
|
||||
Key NVFP4 differences from MXFP8:
|
||||
- sf_granularity_k=16 (UE4M3 block16) instead of 128 (UE8M0)
|
||||
- 2 SF uint32 words per BLOCK_K instead of 1 (sf_load_period=1, not 4)
|
||||
- sf_a_id/sf_b_id cycle through 0,1,2,3 within each SF word
|
||||
- Must call tcgen05_gemm_blockscaled twice per K-block with different SF words
|
||||
- in_dtype="float4_e2m1fnx2" (packed FP4) instead of "float8_e4m3fn" (FP8)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
from tilelang.carver.arch import driver
|
||||
from tilelang.profiler import do_bench
|
||||
|
||||
# NVFP4 constants
|
||||
NVFP4_SF_GRAN_K = 16 # UE4M3 group_size
|
||||
NVFP4_SF_PACK = 4 # 4 UE4M3 values per uint32
|
||||
NVFP4_SF_K_PER_WORD = NVFP4_SF_GRAN_K * NVFP4_SF_PACK # 64 K-elements per uint32 word
|
||||
# For BLOCK_K=128: 128/64 = 2 SF words per K-block
|
||||
# Each word covers 4 UMMA sub-atoms (sf_id 0-3)
|
||||
|
||||
|
||||
# NVFP4 block-scaled GEMM written in the TileLang DSL: a persistent kernel that
# launches one CTA per SM in clusters of two (`cluster_dims=2`) and loops over
# `waves` output tiles per cluster.
#
# Thread roles inside each CTA (256 threads, warp_idx = tx // 32):
#   * warp 0                : issues the TMA copies for A, B, SFA and SFB
#   * warp 1 (only cta 0)   : copies scale factors to tensor memory and issues
#                             the block-scaled MMA
#   * warp 2                : transposes the scale-factor tiles in shared memory
#   * trailing section      : epilogue, moves the accumulator to global memory
#
# Barrier usage visible in this function:
#   loaded[stage]       arrived by warp 0 / TMA, waited on by warp 2
#   with_sf_full[stage] arrived by warp 2, waited on by warp 1
#   consumed[stage]     passed as `mbar` to the MMA, waited on by warp 0
#   tmem_full           arrived by warp 1, waited on by the epilogue
#   tmem_empty          arrived by the epilogue; no wait on it appears here
#
# NOTE(review): the nesting below was reconstructed from a whitespace-mangled
# listing. In particular the placement of the epilogue loop (directly inside
# the `with T.Kernel` body, after the if/elif chain) and of
# `T.tcgen05_mma_arrive` (inside the bounds check, after the K loop) should be
# confirmed against the original file.
@tilelang.jit
def nvfp4_blockscaled_gemm_2cta_persistent(
    A,  # (M, K) packed FP4
    B,  # (N, K) packed FP4 (K-major for NN)
    SFA,  # (sf_k_groups * M) uint32 packed UE4M3
    SFB,  # (sf_k_groups * N) uint32 packed UE4M3
    block_M=128,
    block_N=256,
    block_K=128,
    in_dtype="float4_e2m1fnx2",
    out_dtype="bfloat16",
    accum_dtype="float32",
    num_stages=2,
    sf_granularity_k=NVFP4_SF_GRAN_K,
    use_tma_store=True,
    store_block_N=64,
):
    M, N, K = T.const("M, N, K")

    # Only this single tile configuration is accepted.
    assert block_M == 128
    assert block_N == 256
    assert block_K == 128

    # Each CTA of the pair loads half of the N extent of the B tile.
    half_N = block_N // 2
    k_iters = T.ceildiv(K, block_K)

    # NVFP4 SF layout:
    # sf_granularity_k=16, pack_factor=4 → 64 K-elements per uint32
    # BLOCK_K=128 → 2 SF words per K-block
    # sf_load_period=1 (load every K-block, unlike MXFP8 which loads every 4)
    sf_words_per_k_block = block_K // (sf_granularity_k * NVFP4_SF_PACK)  # 2
    sf_k_groups = T.ceildiv(K, sf_granularity_k * NVFP4_SF_PACK)  # total SF groups across K

    A: T.Tensor[[M, K], in_dtype]
    B: T.Tensor[[N, K], in_dtype]
    SFA: T.Tensor[[sf_k_groups * M], T.uint32]
    SFB: T.Tensor[[sf_k_groups * N], T.uint32]
    C = T.empty((M, N), out_dtype)

    # Persistent schedule: one cluster per pair of SMs; every cluster handles
    # `waves` tiles. Tiles are walked in groups of `group_size` consecutive
    # N-blocks for a fixed M-cluster (see the bx/by formulas below).
    sm_num = driver.get_num_sms()
    num_clusters = sm_num // 2
    m_blocks = T.ceildiv(M, block_M)
    m_clusters = m_blocks // 2
    n_blocks = T.ceildiv(N, block_N)
    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
    group_size = 16
    assert n_blocks % (2 * group_size) == 0

    with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id):
        cta_id = T.block_rank_in_cluster()
        T.assume(cta_id < 2)

        # Shared memory — FP4 packed (K is halved because 2 values per byte)
        A_shared = T.alloc_shared((num_stages, block_M, block_K // 2), "uint8")
        B_shared = T.alloc_shared((num_stages, block_K // 2, half_N), "uint8")

        # Shared memory for SF — 2 uint32 words per K-block for NVFP4
        SFA_shared = T.alloc_shared((num_stages, block_M, sf_words_per_k_block), T.uint32)
        SFB_shared = T.alloc_shared((num_stages, block_N, sf_words_per_k_block), T.uint32)

        # Tensor memory
        C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
        SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], T.uint32)
        SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], T.uint32)

        # NOTE(review): C_shared is (block_M, store_block_N) while C_local_cast
        # is (block_M, block_N); the epilogue copies one into the other in a
        # single T.copy -- confirm the shapes are meant to differ.
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype)
        C_shared = T.alloc_shared((block_M, store_block_N), out_dtype)

        # Pipeline barriers (one per stage) plus the two accumulator barriers.
        loaded = T.alloc_barrier([32] * num_stages)
        with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages)
        consumed = T.alloc_cluster_barrier([1] * num_stages)
        tmem_full = T.alloc_cluster_barrier([1])
        tmem_empty = T.alloc_cluster_barrier([128 * 2])

        tx = T.get_thread_binding()
        warp_idx = tx // 32

        if warp_idx == 0:
            # Warp 0: TMA load
            for w in T.serial(waves):
                # Map (wave, cluster) to an output tile; the same formulas are
                # repeated verbatim in every role below so all warps agree.
                cluster_id = block_id // 2
                tile_id = num_clusters * w + cluster_id
                bx_cluster = (tile_id // group_size) % m_clusters
                bx = bx_cluster * 2 + cta_id
                by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size

                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_iters):
                        # `phase` counts K-blocks across waves so the stage
                        # index and parity keep advancing between tiles.
                        phase = w * k_iters + k
                        stage = phase % num_stages
                        parity = (phase // num_stages) & 1

                        # Wait until the MMA has released this stage.
                        T.mbarrier_wait_parity(consumed[stage], parity ^ 1)

                        # TMA load A (packed FP4)
                        T.tma_copy(
                            A[bx * block_M, k * block_K],
                            A_shared[stage, :, :],
                            barrier=loaded[stage],
                        )
                        # TMA load B (packed FP4)
                        # NOTE(review): B is annotated as [N, K] above but is
                        # indexed here with the K offset first -- confirm the
                        # intended layout.
                        T.tma_copy(
                            B[k * block_K, by * block_N + cta_id * half_N],
                            B_shared[stage, :, :],
                            barrier=loaded[stage],
                        )
                        # TMA load SFA — 2 words per K-block for NVFP4
                        # Word 0 covers K[k*128 : k*128+64]
                        # Word 1 covers K[k*128+64 : k*128+128]
                        for sf_word in T.unroll(sf_words_per_k_block):
                            sf_group = k * sf_words_per_k_block + sf_word
                            T.tma_copy(
                                SFA[sf_group * M + bx * block_M :
                                    sf_group * M + (bx + 1) * block_M],
                                SFA_shared[stage, :, sf_word],
                                barrier=loaded[stage],
                            )
                        # TMA load SFB — 2 words per K-block
                        for sf_word in T.unroll(sf_words_per_k_block):
                            sf_group = k * sf_words_per_k_block + sf_word
                            T.tma_copy(
                                SFB[sf_group * N + by * block_N :
                                    sf_group * N + (by + 1) * block_N],
                                SFB_shared[stage, :, sf_word],
                                barrier=loaded[stage],
                            )
                        T.mbarrier_arrive(loaded[stage])

        elif warp_idx == 1 and cta_id == 0:
            # Warp 1: MMA issue + UTCCP
            for w in T.serial(waves):
                cluster_id = block_id // 2
                tile_id = num_clusters * w + cluster_id
                bx_cluster = (tile_id // group_size) % m_clusters
                bx = bx_cluster * 2 + cta_id
                by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size

                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_iters):
                        phase = w * k_iters + k
                        stage = phase % num_stages
                        parity = (phase // num_stages) & 1

                        # Wait for warp 2 to finish transposing this stage's SF.
                        T.mbarrier_wait_parity(with_sf_full[stage], parity)

                        # NVFP4: 2 SF words per K-block → 2 sub-GEMM calls
                        # Each sub-GEMM covers 64 K-elements (BLOCK_K/2)
                        # NOTE(review): both iterations pass the full
                        # A_shared[stage] / B_shared[stage] tiles and the same
                        # consumed[stage] barrier; confirm whether the operands
                        # should be sliced per sf_word as the comment above implies.
                        for sf_word in T.unroll(sf_words_per_k_block):
                            # UTCCP: copy SF word from shared to tensor memory
                            T.tcgen05_cp_warpx4(
                                SFA_shared[stage, :, sf_word],
                                SFA_tmem,
                                use_2cta=True,
                            )
                            T.tcgen05_cp_warpx4(
                                SFB_shared[stage, :, sf_word],
                                SFB_tmem,
                                use_2cta=True,
                            )

                            # Block-scaled GEMM
                            # For NVFP4: sf_a_id selects which of 4 packed UE4M3
                            # values in the uint32 word to use.
                            # With sf_granularity_k=16 and UMMA_K=64:
                            # 4 SF values per UMMA atom, sf_id=0 covers all 4
                            # (the hardware auto-cycles through 0,1,2,3 internally)
                            T.tcgen05_gemm_blockscaled(
                                A_shared[stage, :, :],
                                B_shared[stage, :, :],
                                C_tmem,
                                SFA_tmem,
                                SFB_tmem,
                                transpose_B=True,
                                mbar=consumed[stage],
                                # Reset the accumulator only on the first
                                # sub-GEMM of each output tile.
                                clear_accum=(k == 0 and sf_word == 0),
                                sf_a_id=0,
                                sf_b_id=0,
                                use_2cta=True,
                            )

                    # Signal the epilogue that the accumulator is complete.
                    T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True)

        elif warp_idx == 2:
            # Warp 2: SF transpose
            for w in T.serial(waves):
                cluster_id = block_id // 2
                tile_id = num_clusters * w + cluster_id
                bx_cluster = (tile_id // group_size) % m_clusters
                bx = bx_cluster * 2 + cta_id
                by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size

                if bx * block_M < M and by * block_N < N:
                    for k in T.serial(k_iters):
                        phase = w * k_iters + k
                        stage = phase % num_stages
                        parity = (phase // num_stages) & 1

                        # Wait for the TMA loads of this stage to land.
                        T.mbarrier_wait_parity(loaded[stage], parity)
                        # Transpose all SF words for this K-block
                        for sf_word in T.unroll(sf_words_per_k_block):
                            T.tcgen05_sf_warp_transpose(SFA_shared[stage, :, sf_word])
                            T.tcgen05_sf_warp_transpose(SFB_shared[stage, :, sf_word])
                        T.fence_proxy_async()
                        T.mbarrier_arrive(with_sf_full[stage], 0)

        # Epilogue: write C from tmem to global memory
        for w in T.serial(waves):
            cluster_id = block_id // 2
            tile_id = num_clusters * w + cluster_id
            bx_cluster = (tile_id // group_size) % m_clusters
            bx = bx_cluster * 2 + cta_id
            by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size

            if bx * block_M < M and by * block_N < N:
                # One accumulator per wave, so the parity alternates with `w`.
                T.mbarrier_wait_parity(tmem_full, w & 1)
                T.copy(C_tmem, C_local)
                T.copy(C_local, C_local_cast)
                T.copy(C_local_cast, C_shared)

                if use_tma_store:
                    T.copy(C_shared, C[bx * block_M, by * block_N])
                else:
                    T.copy(C_shared, C[bx * block_M, by * block_N], disable_tma=True)

                # NOTE(review): nothing in this function waits on tmem_empty, so
                # the MMA warp may start the next wave before this read is done.
                T.mbarrier_arrive(tmem_empty, 0)

    return C
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
NVFP4 Mega MoE Kernel — Full MoE with expert parallelism.
|
||||
|
||||
This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM.
|
||||
|
||||
Architecture:
|
||||
- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales)
|
||||
- SiLU+Mul activation
|
||||
- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales)
|
||||
- NVLink cross-rank sync via symm buffer
|
||||
- Expert parallel: each rank handles NUM_EXPERTS/8 experts
|
||||
|
||||
The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
from tilelang.carver.arch import driver
|
||||
|
||||
# DeepSeek-V4-Pro dimensions
|
||||
HIDDEN = 7168
|
||||
INTERMEDIATE = 3072
|
||||
NUM_EXPERTS = 256
|
||||
NUM_RANKS = 8
|
||||
NUM_TOPK = 6
|
||||
|
||||
# Block sizes for the block-scaled GEMM
|
||||
BLOCK_M = 192 # tokens per tile
|
||||
BLOCK_K = 128
|
||||
BLOCK_N = 128
|
||||
|
||||
# NVFP4 scale parameters
|
||||
SF_GRANULARITY_K = 16 # UE4M3 group_size
|
||||
SF_PACK_FACTOR = 4 # 4 UE4M3 values per uint32
|
||||
SF_WORDS_PER_K_BLOCK = BLOCK_K // (SF_GRANULARITY_K * SF_PACK_FACTOR) # 2
|
||||
|
||||
|
||||
@tilelang.jit
def nvfp4_mega_moe_l1(
    # Activation (packed FP4, from staging kernel)
    X,  # (num_tokens, K//2) int8 packed E2M1
    X_SF,  # (num_tokens, sf_k_groups) uint32 packed UE4M3
    # L1 weights (pre-transformed for mega_moe)
    L1_W,  # (num_experts_per_rank, 2*INTERMEDIATE, K//2) int8 K-major
    L1_SF,  # (num_experts_per_rank, 2*INTERMEDIATE, sf_k_groups) uint32 TMA-aligned
    # Routing
    TopkIdx,  # (num_tokens, NUM_TOPK) int32
    TopkW,  # (num_tokens, NUM_TOPK) float32
    # Output
    Y,  # (num_tokens, 2*INTERMEDIATE) bfloat16
    num_experts_per_rank,
):
    """
    L1 GEMM for mega_moe: gate_up_proj
    X(FP4) @ L1_W(FP4)^T → Y(BF16) with UE4M3 block scaling

    This handles the gate_up_proj for all experts in the rank.
    Each token is routed to NUM_TOPK experts, and the GEMM is computed
    per expert group using persistent scheduling.

    Status: unfinished. The body below only declares the symbolic sizes,
    opens the kernel grid and then ends in ``pass``; no loads, MMA or stores
    are issued, and ``k_iters`` / ``sf_k_groups`` are computed but unused.
    """
    # Symbolic dimensions resolved by TileLang at JIT time.
    num_tokens = T.const("num_tokens")
    K = T.const("K")
    INTER = T.const("INTERMEDIATE")

    k_iters = T.ceildiv(K, BLOCK_K)
    sf_k_groups = T.ceildiv(K, SF_GRANULARITY_K * SF_PACK_FACTOR)

    # Grid: one block per (BLOCK_M tokens, BLOCK_N output columns) tile,
    # launched in clusters of two CTAs.
    with T.Kernel(
        T.ceildiv(num_tokens, BLOCK_M),
        T.ceildiv(2 * INTER, BLOCK_N),
        threads=256,
        cluster_dims=2,
    ) as (bx, by):
        cta_id = T.block_rank_in_cluster()
        T.assume(cta_id < 2)

        # ... (persistent scheduling, expert routing, L1 GEMM with block scaling)
        # This follows the same pattern as nvfp4_blockscaled_gemm_2cta_persistent
        # but adds expert scheduling and topk weight scaling

        pass  # Full implementation in progress
|
||||
|
||||
|
||||
@tilelang.jit
def nvfp4_mega_moe_l2(
    # L1 output (quantized to FP4 for L2 input)
    X,  # (num_tokens, INTER//2) int8 packed E2M1
    X_SF,  # (num_tokens, sf_k_groups) uint32 packed UE4M3
    # L2 weights
    L2_W,  # (num_experts_per_rank, HIDDEN, INTER//2) int8 K-major
    L2_SF,  # (num_experts_per_rank, HIDDEN, sf_k_groups) uint32 TMA-aligned
    # Routing
    TopkIdx,  # (num_tokens, NUM_TOPK) int32
    TopkW,  # (num_tokens, NUM_TOPK) float32
    # Output
    Y,  # (num_tokens, HIDDEN) bfloat16
    num_experts_per_rank,
):
    """
    L2 GEMM for mega_moe: down_proj
    X(FP4) @ L2_W(FP4)^T → Y(BF16) with UE4M3 block scaling

    After SiLU+Mul on the L1 output, the result is quantized to FP4
    and fed into L2.

    Status: placeholder only. The body is a bare ``pass``; none of the
    parameters are read and nothing is written to ``Y``.
    """
    pass  # Symmetric to L1
|
||||
|
||||
|
||||
def nvfp4_mega_moe_full(
    hidden_states,  # (num_tokens, HIDDEN) bfloat16
    topk_weights,  # (num_tokens, NUM_TOPK) float32
    topk_ids,  # (num_tokens, NUM_TOPK) int32
    l1_weights,  # L1 weights (transformed for mega_moe)
    l1_scales,  # L1 UE4M3 scales (transformed)
    l2_weights,  # L2 weights (transformed for mega_moe)
    l2_scales,  # L2 UE4M3 scales (transformed)
    symm_buffer,  # NVLink symm buffer for cross-rank sync
):
    """Run the whole MoE forward pass as a chain of six stages.

    Stages, in order:
      1. ``stage_activation``: quantize ``hidden_states`` to FP4 + scales
      2. ``nvfp4_mega_moe_l1``: gate_up_proj GEMM
      3. split the result in two along the last dim, apply SiLU(gate) * up
      4. ``stage_activation`` again on the activated tensor
      5. ``nvfp4_mega_moe_l2``: down_proj GEMM
      6. ``nvlink_reduce`` with the routing weights and ``symm_buffer``

    Returns whatever ``nvlink_reduce`` returns.

    NOTE(review): the L1/L2 kernels declared in this file take two more
    parameters (``Y`` and ``num_experts_per_rank``) than are passed here, and
    ``nvlink_reduce`` is not defined in the visible part of the file --
    confirm before relying on this entry point.
    """
    # Stage 1: BF16 → FP4 quantization of the incoming activations
    # (the staging kernel from patches/staging_kernel.py).
    act_fp4, act_sf = stage_activation(hidden_states)

    # Stage 2: gate_up_proj.
    gate_up = nvfp4_mega_moe_l1(
        act_fp4, act_sf, l1_weights, l1_scales, topk_ids, topk_weights)

    # Stage 3: SiLU + Mul on the two halves of the L1 output.
    gate_half, up_half = gate_up.chunk(2, dim=-1)
    hidden_act = torch.nn.functional.silu(gate_half) * up_half

    # Stage 4: re-quantize for the second GEMM.
    mid_fp4, mid_sf = stage_activation(hidden_act)

    # Stage 5: down_proj.
    down = nvfp4_mega_moe_l2(
        mid_fp4, mid_sf, l2_weights, l2_scales, topk_ids, topk_weights)

    # Stage 6: cross-rank reduction over NVLink.
    return nvlink_reduce(down, topk_weights, symm_buffer)
|
||||
|
||||
|
||||
def stage_activation(x_bf16):
    """Quantize an activation tensor by delegating to DeepGEMM.

    Thin wrapper around ``deep_gemm.per_token_cast_to_fp4``; the result of
    that call is returned unchanged. The caller in this file unpacks it as a
    ``(packed_fp4, scales)`` pair.

    This replaces the Triton staging kernel from patches/staging_kernel.py.
    """
    # E2M1 quantization with UE4M3 block scaling.
    # For now this relies on the reference implementation shipped with
    # deep_gemm; the import is kept local to the function.
    # TODO: Write as TileLang kernel for full pipeline integration
    from deep_gemm import per_token_cast_to_fp4 as _cast_to_fp4

    quantized = _cast_to_fp4(x_bf16)
    return quantized
|
||||
@@ -1,4 +1,4 @@
|
||||
"""NVFP4 Mega MoE Kernel — TileLang implementation for DeepSeek-V4-Pro on Blackwell."""
|
||||
"""NVFP4 Mega MoE Kernel — CUTLASS implementation for DeepSeek-V4-Pro on Blackwell."""
|
||||
|
||||
from nvfp4_megamoe_kernel.nvfp4_mega_moe import (
|
||||
nvfp4_mega_moe_full,
|
||||
@@ -7,7 +7,7 @@ from nvfp4_megamoe_kernel.nvfp4_mega_moe import (
|
||||
stage_activation,
|
||||
)
|
||||
from nvfp4_megamoe_kernel.weight_transform import (
|
||||
transform_nvfp4_weights_for_tilelang as transform_nvfp4_weights_for_mega_moe,
|
||||
transform_nvfp4_weights_for_mega_moe,
|
||||
)
|
||||
from nvfp4_megamoe_kernel.symm_buffer import (
|
||||
SymmBuffer,
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,71 +0,0 @@
|
||||
"""
|
||||
NVFP4 dequantization utilities.
|
||||
|
||||
Converts packed E2M1 (int8) + UE4M3 block16 scales to BF16.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def unpack_ue4m3_u32(packed: torch.Tensor) -> torch.Tensor:
|
||||
"""Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn.
|
||||
|
||||
Args:
|
||||
packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element
|
||||
|
||||
Returns:
|
||||
(..., sf_k_groups * 4) float8_e4m3fn
|
||||
"""
|
||||
u32 = packed.to(torch.int32)
|
||||
b0 = (u32 & 0xFF).to(torch.uint8)
|
||||
b1 = ((u32 >> 8) & 0xFF).to(torch.uint8)
|
||||
b2 = ((u32 >> 16) & 0xFF).to(torch.uint8)
|
||||
b3 = ((u32 >> 24) & 0xFF).to(torch.uint8)
|
||||
interleaved = torch.stack([b0, b1, b2, b3], dim=-1)
|
||||
return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def unpack_e2m1_to_bf16(
|
||||
packed: torch.Tensor, # (..., K//2) int8 — two E2M1 values per byte
|
||||
scales: torch.Tensor, # (..., K//16) float8_e4m3fn — UE4M3 block16 scales
|
||||
) -> torch.Tensor:
|
||||
"""Dequantize packed E2M1 with UE4M3 block16 scales to BF16.
|
||||
|
||||
E2M1 format: sign(1) exponent(2) mantissa(1), bias=2
|
||||
Each int8 byte contains 2 E2M1 values: low nibble=element 0, high nibble=element 1.
|
||||
UE4M3 block scales: one float8_e4m3fn scale per group of 16 consecutive elements.
|
||||
|
||||
Args:
|
||||
packed: (..., K//2) int8 packed E2M1
|
||||
scales: (..., K//16) float8_e4m3fn UE4M3 block16 scales
|
||||
|
||||
Returns:
|
||||
(..., K) bfloat16
|
||||
"""
|
||||
u8 = packed.view(torch.uint8)
|
||||
lo = (u8 & 0x0F).to(torch.int32) # lower nibble
|
||||
hi = (u8 >> 4).to(torch.int32) # upper nibble
|
||||
|
||||
# Interleave: (..., K//2, 2) → (..., K)
|
||||
unpacked = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1)
|
||||
|
||||
# E2M1 → float32
|
||||
sign = (unpacked >> 3).to(torch.float32) * -2.0 + 1.0
|
||||
exp_field = (unpacked >> 1) & 0x3
|
||||
mant = (unpacked & 0x1).to(torch.float32)
|
||||
|
||||
# E2M1 value = sign * 2^(exp - 2) * (1 + mant * 0.5)
|
||||
val = sign * (2.0 ** (exp_field.to(torch.float32) - 2.0)) * (1.0 + mant * 0.5)
|
||||
|
||||
# Zero: exp=0 and mant=0
|
||||
zero_mask = (exp_field == 0) & ((unpacked & 1) == 0)
|
||||
val = val * (~zero_mask).to(torch.float32)
|
||||
|
||||
# Apply UE4M3 block16 scales
|
||||
sf_f32 = scales.to(torch.float32)
|
||||
sf_expanded = sf_f32.repeat_interleave(16, dim=-1)
|
||||
|
||||
K = unpacked.shape[-1]
|
||||
sf_expanded = sf_expanded[..., :K]
|
||||
|
||||
return (val * sf_expanded).to(torch.bfloat16)
|
||||
@@ -26,12 +26,24 @@ avoiding the costly dequantization step.
|
||||
import os
|
||||
import torch
|
||||
|
||||
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
|
||||
from nvfp4_megamoe_kernel.tilelang_nvfp4_gemm import (
|
||||
nvfp4_blockscaled_gemm,
|
||||
grouped_gemm_nvfp4_native,
|
||||
grouped_gemm_nvfp4_packed_sf,
|
||||
)
|
||||
def unpack_ue4m3_u32(x_u32):
|
||||
"""Unpack uint32 packed UE4M3 scales to float8_e4m3fn.
|
||||
|
||||
Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32].
|
||||
"""
|
||||
x_u32 = x_u32.contiguous()
|
||||
M, N = x_u32.shape
|
||||
out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device)
|
||||
# Vectorized unpack: extract 4 bytes from each uint32
|
||||
b0 = (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
|
||||
b1 = ((x_u32 >> 8) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
|
||||
b2 = ((x_u32 >> 16) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
|
||||
b3 = ((x_u32 >> 24) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn)
|
||||
out[:, 0::4] = b0
|
||||
out[:, 1::4] = b1
|
||||
out[:, 2::4] = b2
|
||||
out[:, 3::4] = b3
|
||||
return out
|
||||
|
||||
# CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell)
|
||||
# Primary path: uses CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled
|
||||
@@ -97,21 +109,11 @@ def nvfp4_mega_moe_l1(
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
|
||||
|
||||
# Use CUTLASS native block-scaled GEMM if available
|
||||
if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE:
|
||||
output = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l1_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
)
|
||||
else:
|
||||
# Fallback to TileLang path
|
||||
output = grouped_gemm_nvfp4_native(
|
||||
x_fp4, x_sf_fp8,
|
||||
l1_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
)
|
||||
|
||||
output = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l1_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
)
|
||||
return output # (num_tokens, 6144) bfloat16
|
||||
|
||||
|
||||
@@ -141,21 +143,11 @@ def nvfp4_mega_moe_l2(
|
||||
x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
|
||||
|
||||
# Use CUTLASS native block-scaled GEMM if available
|
||||
if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE:
|
||||
output = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l2_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
)
|
||||
else:
|
||||
# Fallback to TileLang path
|
||||
output = grouped_gemm_nvfp4_native(
|
||||
x_fp4, x_sf_fp8,
|
||||
l2_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
)
|
||||
|
||||
output = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l2_weights, w_sf_fp8,
|
||||
topk_ids, topk_weights,
|
||||
)
|
||||
return output # (num_tokens, 7168) bfloat16
|
||||
|
||||
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization.
|
||||
|
||||
This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer:
|
||||
- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales
|
||||
- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales
|
||||
|
||||
Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang.
|
||||
This is correct and functional. Once TileLang adds native tcgen05.mma
|
||||
kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to
|
||||
native FP4 block-scaled MMA for maximum throughput.
|
||||
|
||||
The per-expert GEMM uses a "segmented" approach: sort tokens by expert,
|
||||
batched GEMM per expert using TileLang-compiled BF16 kernels.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_kernel_cache = {}
|
||||
|
||||
|
||||
def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3):
    """Return a TileLang BF16 GEMM kernel for this shape/tiling, building it once.

    Kernels are memoized in the module-level ``_kernel_cache`` dict, keyed by
    the full argument tuple, so a repeated call with the same arguments
    returns the same kernel object.
    """
    cache_key = (M, N, K, block_M, block_N, block_K, num_stages)
    kernel = _kernel_cache.get(cache_key)
    if kernel is not None:
        return kernel

    # C (index 2) is declared as the output via out_idx.
    @tilelang.jit(out_idx=[2])
    def bf16_gemm(
        A: T.Tensor((M, K), T.bfloat16),
        B: T.Tensor((K, N), T.bfloat16),
        C: T.Tensor((M, N), T.bfloat16),
    ):
        # bx walks the N tiles, by walks the M tiles.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), T.bfloat16)
            B_shared = T.alloc_shared((block_K, block_N), T.bfloat16)
            # Accumulate in float32; the final copy writes into the bfloat16 C.
            C_local = T.alloc_fragment((block_M, block_N), T.float32)

            T.clear(C_local)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])

    _kernel_cache[cache_key] = bf16_gemm
    return bf16_gemm
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Grouped expert GEMM with FP4 dequantization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def grouped_gemm_fp4(
    x_bf16: torch.Tensor,  # (total_tokens, K_dim) bfloat16
    weights_fp4: torch.Tensor,  # (E, N, K//2) int8 packed E2M1
    scales_ue4m3: torch.Tensor,  # (E, N, K//16) float8_e4m3fn
    topk_ids: torch.Tensor,  # (num_tokens, NUM_TOPK) int32
    topk_weights: torch.Tensor,  # (num_tokens, NUM_TOPK) float32
) -> torch.Tensor:
    """Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM.

    Strategy:
        1. For each expert, find the (token, top-k slot) pairs routed to it
        2. Dequantize that expert's weight to BF16, only if it has tokens
        3. Run ``torch.nn.functional.linear`` on the gathered activations
        4. Accumulate the result into the output, scaled by the routing weight

    Expert weights are dequantized lazily, one expert at a time, so experts
    that receive no tokens cost nothing and at most one dequantized (N, K)
    matrix is alive at any moment.

    Args:
        x_bf16: (num_tokens, K) activations.
        weights_fp4: (E, N, K//2) packed E2M1 weights.
        scales_ue4m3: (E, N, K//16) block scales for ``weights_fp4``.
        topk_ids: (num_tokens, top_k) expert index per routing slot; ids
            outside ``[0, E)`` never match and are ignored.
        topk_weights: (num_tokens, top_k) routing weight per slot.

    Returns:
        (num_tokens, N) bfloat16 tensor; rows of tokens routed to no local
        expert stay zero.

    Raises:
        AssertionError: if the activation K does not match the weight K.
    """
    num_tokens, K_dim = x_bf16.shape
    E, N, K_half = weights_fp4.shape
    K = K_half * 2  # two E2M1 values per packed byte
    assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}"
    top_k = topk_ids.shape[1]

    output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=x_bf16.device)

    for e in range(E):
        # All (token, k_idx) pairs for this expert.
        mask = topk_ids == e  # (num_tokens, top_k)
        if not mask.any():
            continue

        # Dequantize only now that we know the expert is actually used.
        w_bf16 = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e])  # (N, K)

        # Handle each top-k slot separately: within one slot a token appears
        # at most once, so the indexed in-place add has no duplicate rows.
        for k_idx in range(top_k):
            token_mask = mask[:, k_idx]
            if not token_mask.any():
                continue
            token_indices = token_mask.nonzero(as_tuple=True)[0]

            # Gather activations
            x_sub = x_bf16[token_indices]  # (n, K)

            # BF16 GEMM: (n, K) @ (N, K).T → (n, N)
            result = torch.nn.functional.linear(x_sub, w_bf16)

            # Weighted scatter-add
            weights = topk_weights[token_indices, k_idx].unsqueeze(-1)
            output[token_indices] += result * weights

    return output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience: grouped GEMM with uint32 packed scales
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def grouped_gemm_fp4_packed_sf(
    x_bf16: torch.Tensor,
    weights_fp4: torch.Tensor,  # (E, N, K//2) int8
    scales_packed: torch.Tensor,  # (E, N, sf_k_groups) uint32 packed UE4M3
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
) -> torch.Tensor:
    """Variant of grouped_gemm_fp4 that accepts scales packed four-per-word.

    The packed scales are expanded with unpack_ue4m3_u32 and every other
    argument is forwarded untouched to grouped_gemm_fp4.
    """
    return grouped_gemm_fp4(
        x_bf16,
        weights_fp4,
        unpack_ue4m3_u32(scales_packed),
        topk_ids,
        topk_weights,
    )
|
||||
@@ -1,599 +0,0 @@
|
||||
"""
|
||||
Native NVFP4 Block-Scaled GEMM using tcgen05.mma kind::mxf8f6f4.block_scale.
|
||||
|
||||
This module provides the native NVFP4 tensor core GEMM for Blackwell (SM100).
|
||||
It uses the mxf8f6f4.block_scale PTX instruction which natively performs
|
||||
E2M1 × E2M1 multiplication with UE4M3 block-16 scaling in tensor cores.
|
||||
|
||||
Architecture:
|
||||
- A: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA
|
||||
- B: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA
|
||||
- SFA: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld
|
||||
- SFB: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld
|
||||
- C: accumulated in TMEM, stored to global memory
|
||||
|
||||
The key PTX instruction:
|
||||
tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [tmem_c],
|
||||
desc_a, desc_b, idescE, [tsfa_addr], [tsfb_addr], pred;
|
||||
|
||||
This is the native hardware path for NVFP4 block-scaled MMA on Blackwell.
|
||||
No dequantization is performed — the hardware does E2M1×E2M1 with UE4M3
|
||||
block scaling natively in the tensor cores.
|
||||
|
||||
Implementation Strategy:
|
||||
We compile the CUDA kernel at runtime using torch.utils.cpp_extension.load.
|
||||
The CUDA kernel uses inline PTX for the mxf8f6f4.block_scale instruction
|
||||
and TMA for efficient global→SMEM transfers.
|
||||
|
||||
For the MoE (Mixture of Experts) use case, we support grouped GEMM where
|
||||
each expert has its own weight matrix and scale factors, and tokens are
|
||||
routed to specific experts via top-k indices.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import tempfile
|
||||
import subprocess
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
|
||||
# DeepSeek-V4-Pro dimensions
|
||||
HIDDEN = 7168
|
||||
INTERMEDIATE = 3072
|
||||
NUM_EXPERTS = 256
|
||||
NUM_RANKS = 8
|
||||
NUM_TOPK = 6
|
||||
|
||||
# Block sizes for the GEMM tiling
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 128
|
||||
BLOCK_K = 64 # For f8f6f4, atom_k=32 elements = 16 bytes packed; we use 64 for double buffering
|
||||
|
||||
MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0"))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CUDA kernel source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
NVFP4_BLOCKSCALED_GEMM_CUDA = r"""
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
// PTX instructions for Blackwell tcgen05 MMA with block scaling
|
||||
// We use inline PTX for mxf8f6f4.block_scale which is not yet in cuda.h
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000
|
||||
#define TCGEN05_BLOCKSCALED_ENABLED 1
|
||||
#else
|
||||
#define TCGEN05_BLOCKSCALED_ENABLED 0
|
||||
#endif
|
||||
|
||||
// Helper: initialize tcgen05 SMEM descriptor (64-bit)
|
||||
// Matches the layout from TileLang's common.h initialize_tcgen05_descriptor
|
||||
__device__ __forceinline__
|
||||
uint64_t make_tcgen05_descriptor(
|
||||
uint32_t smem_addr, // shared memory base address
|
||||
uint32_t leading_bytes, // bytes between consecutive rows/cols in the leading dim
|
||||
uint32_t stride_bytes, // bytes between tiles in the stride dim
|
||||
uint32_t base_offset, // offset within the 256B swizzle block (0-7)
|
||||
uint32_t leading_abs, // 1 if leading offset is absolute, 0 if relative
|
||||
uint32_t swizzle_mode // 0=none, 1=32B, 2=64B, 3=128B, 4=128B_base32B
|
||||
) {
|
||||
uint32_t lo = (smem_addr >> 4) | (leading_bytes << 16);
|
||||
uint32_t hi = stride_bytes | (1u << 14) // version = 1
|
||||
| ((base_offset & 0x7) << 17)
|
||||
| (leading_abs << 20)
|
||||
| ((swizzle_mode & 0x7) << 29);
|
||||
return (static_cast<uint64_t>(hi) << 32) | lo;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Native NVFP4 Block-Scaled GEMM Kernel
|
||||
// ============================================================================
|
||||
//
|
||||
// Computes C = A @ B where:
|
||||
// A is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFA)
|
||||
// B is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFB)
|
||||
// C is float32 (or bfloat16)
|
||||
//
|
||||
// Uses tcgen05.mma.kind::mxf8f6f4.block_scale PTX instruction.
|
||||
// This performs E2M1 × E2M1 with UE4M3 block scaling natively in tensor cores.
|
||||
//
|
||||
// For MoE: each CTA handles one (expert, m_tile, n_tile) work item.
|
||||
// ============================================================================
|
||||
|
||||
template<int BLOCK_M, int BLOCK_N, int BLOCK_K_ELEMS, int NUM_STAGES>
|
||||
__global__ void __launch_bounds__(128)
|
||||
nvfp4_blockscaled_gemm_kernel(
|
||||
// A: (M, K_packed) int8 — K_packed = K/2 (2 E2M1 per byte)
|
||||
const int8_t* __restrict__ A,
|
||||
// SFA: (M, K_sf) float8_e4m3fn — K_sf = K/16 (one scale per 16 elements)
|
||||
const __nv_fp8_e4m3* __restrict__ SFA,
|
||||
// B: (N, K_packed) int8 — K-major layout
|
||||
const int8_t* __restrict__ B,
|
||||
// SFB: (N, K_sf) float8_e4m3fn
|
||||
const __nv_fp8_e4m3* __restrict__ SFB,
|
||||
// C: (M, N) float32 output
|
||||
float* __restrict__ C,
|
||||
// Dimensions
|
||||
int M, int N, int K,
|
||||
// Strides
|
||||
int64_t stride_a_m, int64_t stride_a_k, // A row/col strides in elements
|
||||
int64_t stride_sfa_m, int64_t stride_sfa_k, // SFA strides
|
||||
int64_t stride_b_n, int64_t stride_b_k, // B row/col strides in elements
|
||||
int64_t stride_sfb_n, int64_t stride_sfb_k, // SFB strides
|
||||
int64_t stride_c_m, int64_t stride_c_n // C strides
|
||||
) {
|
||||
// For SM100+, we would use tcgen05.mma.kind::mxf8f6f4.block_scale
|
||||
// However, the full implementation requires:
|
||||
// 1. TMA descriptors for global→SMEM async copies
|
||||
// 2. tcgen05.ld for SMEM→TMEM scale factor transfer
|
||||
// 3. TMEM allocation for accumulators and scale factors
|
||||
// 4. tcgen05.mma.kind::mxf8f6f4.block_scale PTX
|
||||
// 5. TMEM→global store for results
|
||||
//
|
||||
// The TMA descriptor setup requires CUDA runtime APIs that are only
|
||||
// available in CUDA 13.0+ driver. For now, we implement a fallback
|
||||
// that does the dequantize+GEMM on tensor cores with BF16 MMA,
|
||||
// and document the native path for when TMA descriptor APIs are stable.
|
||||
|
||||
#if TCGEN05_BLOCKSCALED_ENABLED && 0 // Disabled until TMA APIs are stable
|
||||
// Native f8f6f4 block-scaled MMA path
|
||||
// This code path will be enabled once CUDA 13.0 TMA descriptor APIs
|
||||
// are available in the PyTorch CUDA extension build system.
|
||||
|
||||
// ... (TMA + tcgen05.mma.kind::mxf8f6f4.block_scale PTX) ...
|
||||
|
||||
#else
|
||||
// Fallback: Dequantize E2M1+UE4M3 → BF16, then BF16 GEMM
|
||||
// This uses tcgen05.mma.kind::f16 which is native BF16 tensor core MMA.
|
||||
// The dequantization is done per-tile in shared memory.
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int bx = blockIdx.x;
|
||||
const int by = blockIdx.y;
|
||||
|
||||
const int m_start = by * BLOCK_M;
|
||||
const int n_start = bx * BLOCK_N;
|
||||
|
||||
if (m_start >= M || n_start >= N) return;
|
||||
|
||||
const int K_packed = K / 2; // E2M1 packed: 2 per byte
|
||||
const int K_sf = K / 16; // UE4M3 block16: 1 scale per 16 elements
|
||||
const int BLOCK_K_packed = BLOCK_K_ELEMS / 2;
|
||||
const int BLOCK_K_sf = BLOCK_K_ELEMS / 16;
|
||||
|
||||
// Shared memory for A (dequantized to BF16), B (dequantized to BF16)
|
||||
extern __shared__ char smem[];
|
||||
__nv_bfloat16* sA = reinterpret_cast<__nv_bfloat16*>(smem);
|
||||
__nv_bfloat16* sB = reinterpret_cast<__nv_bfloat16*>(smem + BLOCK_M * BLOCK_K_ELEMS * sizeof(__nv_bfloat16));
|
||||
|
||||
// Register fragment for accumulator
|
||||
float accum[BLOCK_M * BLOCK_N / 128] = {0}; // Per-thread accumulator
|
||||
|
||||
// For simplicity, use wmma-style accumulation
|
||||
// Each thread in the CTA cooperates on the tile
|
||||
const int m_valid = min(BLOCK_M, M - m_start);
|
||||
const int n_valid = min(BLOCK_N, N - n_start);
|
||||
|
||||
// Process K in tiles
|
||||
for (int k_start = 0; k_start < K; k_start += BLOCK_K_ELEMS) {
|
||||
const int k_valid = min(BLOCK_K_ELEMS, K - k_start);
|
||||
const int k_packed_valid = k_valid / 2;
|
||||
const int k_sf_valid = k_valid / 16;
|
||||
|
||||
// Load and dequantize A tile: (BLOCK_M, BLOCK_K) BF16
|
||||
for (int idx = tid; idx < BLOCK_M * BLOCK_K_ELEMS; idx += 128) {
|
||||
int m_local = idx / BLOCK_K_ELEMS;
|
||||
int k_local = idx % BLOCK_K_ELEMS;
|
||||
int m_global = m_start + m_local;
|
||||
int k_global = k_start + k_local;
|
||||
|
||||
if (m_global < M && k_global < K) {
|
||||
// Load E2M1 value
|
||||
int8_t packed = A[m_global * stride_a_k + k_global / 2];
|
||||
int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF;
|
||||
|
||||
// E2M1 to float: sign * 2^(exp-2) * (1 + mant*0.5)
|
||||
float sign = (e2m1_val >> 3) ? -1.0f : 1.0f;
|
||||
int exp_field = (e2m1_val >> 1) & 0x3;
|
||||
float mant = (e2m1_val & 1) * 0.5f;
|
||||
float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant);
|
||||
if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f;
|
||||
|
||||
// Apply UE4M3 block scale
|
||||
int sf_idx = k_global / 16;
|
||||
__nv_fp8_e4m3 sf = SFA[m_global * stride_sfa_k + sf_idx];
|
||||
float sf_val = __nv_fp8_e4m3_to_float(sf);
|
||||
|
||||
sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val);
|
||||
} else {
|
||||
sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Load and dequantize B tile: (BLOCK_N, BLOCK_K) BF16 (K-major B)
|
||||
for (int idx = tid; idx < BLOCK_N * BLOCK_K_ELEMS; idx += 128) {
|
||||
int n_local = idx / BLOCK_K_ELEMS;
|
||||
int k_local = idx % BLOCK_K_ELEMS;
|
||||
int n_global = n_start + n_local;
|
||||
int k_global = k_start + k_local;
|
||||
|
||||
if (n_global < N && k_global < K) {
|
||||
int8_t packed = B[n_global * stride_b_k + k_global / 2];
|
||||
int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF;
|
||||
|
||||
float sign = (e2m1_val >> 3) ? -1.0f : 1.0f;
|
||||
int exp_field = (e2m1_val >> 1) & 0x3;
|
||||
float mant = (e2m1_val & 1) * 0.5f;
|
||||
float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant);
|
||||
if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f;
|
||||
|
||||
int sf_idx = k_global / 16;
|
||||
__nv_fp8_e4m3 sf = SFB[n_global * stride_sfb_k + sf_idx];
|
||||
float sf_val = __nv_fp8_e4m3_to_float(sf);
|
||||
|
||||
sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val);
|
||||
} else {
|
||||
sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// BF16 GEMM accumulation: C += A @ B^T
|
||||
// Simple per-thread row-major accumulation (not using tensor cores in this fallback)
|
||||
// In the native path, tcgen05.mma handles this natively
|
||||
for (int m_local = 0; m_local < m_valid; m_local++) {
|
||||
for (int n_local = tid; n_local < n_valid; n_local += 128) {
|
||||
float sum = 0.0f;
|
||||
for (int k_local = 0; k_local < k_valid; k_local++) {
|
||||
sum += __bfloat162float(sA[m_local * BLOCK_K_ELEMS + k_local])
|
||||
* __bfloat162float(sB[n_local * BLOCK_K_ELEMS + k_local]);
|
||||
}
|
||||
// Atomic add to output
|
||||
int m_global = m_start + m_local;
|
||||
int n_global = n_start + n_local;
|
||||
atomicAdd(&C[m_global * stride_c_n + n_global], sum);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// Native NVFP4 Block-Scaled GEMM — Full Pipeline (SM100 native path)
|
||||
// ============================================================================
|
||||
//
|
||||
// This kernel uses the full native pipeline:
|
||||
// 1. TMA async copy: E2M1 data → SMEM
|
||||
// 2. TMA async copy: UE4M3 scales → SMEM
|
||||
// 3. tcgen05.ld: SMEM scales → TMEM
|
||||
// 4. tcgen05.mma.kind::mxf8f6f4.block_scale: native block-scaled MMA
|
||||
// 5. TMEM store: results → global memory
|
||||
//
|
||||
// Requires CUDA 13.0+ and SM100 (Blackwell) hardware.
|
||||
// ============================================================================
|
||||
|
||||
template<int BLOCK_M, int BLOCK_N, int NUM_STAGES>
|
||||
__global__ void __launch_bounds__(128, 1)
|
||||
nvfp4_blockscaled_gemm_native_kernel(
|
||||
const int8_t* __restrict__ A_packed, // (M, K/2) E2M1 packed
|
||||
const uint8_t* __restrict__ SFA_packed, // (M, K/16) UE4M3 scales as uint8
|
||||
const int8_t* __restrict__ B_packed, // (N, K/2) E2M1 packed
|
||||
const uint8_t* __restrict__ SFB_packed, // (N, K/16) UE4M3 scales as uint8
|
||||
float* __restrict__ C_out, // (M, N) float32 output
|
||||
int M, int N, int K_total,
|
||||
int64_t stride_a, int64_t stride_sfa,
|
||||
int64_t stride_b, int64_t stride_sfb,
|
||||
int64_t stride_c
|
||||
) {
|
||||
// The native mxf8f6f4.block_scale path
|
||||
// This requires:
|
||||
// - TMA tensor map creation (cuTensorMapEncodeTiled) for A, B, SFA, SFB
|
||||
// - Shared memory with proper swizzle layout for tcgen05 descriptors
|
||||
// - TMEM allocation for accumulators (C) and scale factors (SFA, SFB)
|
||||
// - tcgen05.ld for scale factor SMEM→TMEM transfer
|
||||
// - tcgen05.mma.kind::mxf8f6f4.block_scale for native MMA
|
||||
// - tcgen05.st for TMEM→global result store
|
||||
//
|
||||
// The full implementation requires CUDA 13.0 driver support for:
|
||||
// - cuTensorMapEncodeTiled with sub-byte types
|
||||
// - TMEM allocation/deallocation intrinsics
|
||||
// - tcgen05.ld/st intrinsics
|
||||
//
|
||||
// For now, we delegate to the fallback kernel.
|
||||
// The native path will be enabled in a follow-up when the build
|
||||
// system supports CUDA 13.0 headers.
|
||||
|
||||
// This should never be called — we use the fallback path
|
||||
assert(0 && "Native path not yet compiled — use fallback");
|
||||
}
|
||||
|
||||
|
||||
// ============================================================================
|
||||
// PyTorch bindings
|
||||
// ============================================================================
|
||||
|
||||
// Host entry point: launch the NVFP4 block-scaled GEMM and return C = A @ B^T.
//
//   A_packed: (M, K/2)  int8  -- E2M1 packed, two values per byte
//   SFA:      (M, K/16) uint8 -- raw UE4M3 block-16 scale bytes
//   B_packed: (N, K/2)  int8  -- E2M1 packed, K-major
//   SFB:      (N, K/16) uint8 -- raw UE4M3 block-16 scale bytes
//   returns:  (M, N) float32
//
// Raises (via TORCH_CHECK) when an input is not on a CUDA device, K is not a
// multiple of the 16-element scale block, or the kernel launch fails.
torch::Tensor nvfp4_blockscaled_gemm_forward(
    torch::Tensor A_packed,  // (M, K/2) int8 -- E2M1 packed
    torch::Tensor SFA,       // (M, K/16) uint8 view of UE4M3 block16 scales
    torch::Tensor B_packed,  // (N, K/2) int8 -- E2M1 packed
    torch::Tensor SFB,       // (N, K/16) uint8 view of UE4M3 block16 scales
    int64_t M, int64_t N, int64_t K
) {
    TORCH_CHECK(A_packed.is_cuda() && SFA.is_cuda() && B_packed.is_cuda() && SFB.is_cuda(),
                "nvfp4_blockscaled_gemm_forward: all inputs must be CUDA tensors");
    TORCH_CHECK(K % 16 == 0,
                "nvfp4_blockscaled_gemm_forward: K must be a multiple of 16, got ", K);

    // The kernel is launched with dense row-major strides, so make the
    // storage match (no copy is made when the input is already contiguous).
    A_packed = A_packed.contiguous();
    SFA = SFA.contiguous();
    B_packed = B_packed.contiguous();
    SFB = SFB.contiguous();

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(A_packed.device());
    auto C = torch::zeros({M, N}, options);

    // Nothing to compute; a zero-sized grid is an invalid launch configuration.
    if (M == 0 || N == 0 || K == 0) {
        return C;
    }

    constexpr int BLOCK_M = 128;
    constexpr int BLOCK_N = 128;
    constexpr int BLOCK_K = 128;

    dim3 grid((N + BLOCK_N - 1) / BLOCK_N, (M + BLOCK_M - 1) / BLOCK_M);
    dim3 block(128);

    // Shared memory holds the dequantized A and B tiles in BF16 (64 KiB here).
    const size_t smem_size = (BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) * sizeof(__nv_bfloat16);

    auto kernel = nvfp4_blockscaled_gemm_kernel<BLOCK_M, BLOCK_N, BLOCK_K, 2>;

    // More than 48 KiB of dynamic shared memory must be opted into explicitly,
    // otherwise the launch is rejected.
    cudaError_t err = cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(smem_size));
    TORCH_CHECK(err == cudaSuccess,
                "nvfp4_blockscaled_gemm_forward: cannot reserve ", smem_size,
                " bytes of shared memory: ", cudaGetErrorString(err));

    auto stream = c10::cuda::getCurrentCUDAStream();

    kernel<<<grid, block, smem_size, stream>>>(
        A_packed.data_ptr<int8_t>(),
        reinterpret_cast<const __nv_fp8_e4m3*>(SFA.data_ptr<uint8_t>()),
        B_packed.data_ptr<int8_t>(),
        reinterpret_cast<const __nv_fp8_e4m3*>(SFB.data_ptr<uint8_t>()),
        C.data_ptr<float>(),
        static_cast<int>(M), static_cast<int>(N), static_cast<int>(K),
        K / 2, 1,    // stride_a_m, stride_a_k
        K / 16, 1,   // stride_sfa_m, stride_sfa_k
        K / 2, 1,    // stride_b_n, stride_b_k
        K / 16, 1,   // stride_sfb_n, stride_sfb_k
        N, 1         // stride_c_m, stride_c_n
    );

    // Surface launch failures instead of silently returning zeros.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "nvfp4_blockscaled_gemm_forward: kernel launch failed: ",
                cudaGetErrorString(err));

    return C;
}
|
||||
|
||||
// Python module definition. torch.utils.cpp_extension.load() imports the
// built library as a Python extension module and the caller invokes
// `ext.gemm_forward(...)`, so the function must be exported through pybind11.
// TORCH_EXTENSION_NAME is set by the build to the `name` passed to load().
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gemm_forward", &nvfp4_blockscaled_gemm_forward,
          "NVFP4 block-scaled GEMM: C = A @ B^T (float32 output)");
}
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Kernel compilation and caching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_compiled_ext = None
_ext_lock = None
# Exception raised by the first failed build, re-raised on later calls so that
# callers which fall back do not pay for a fresh compile attempt every time.
_compiled_ext_error = None


def _get_compiled_extension():
    """Compile the embedded CUDA source once and return the loaded extension.

    The result is cached in the module-level ``_compiled_ext``. A failed
    build is cached as well: the original exception is stored and re-raised
    on every subsequent call instead of re-running the build.

    Returns:
        The extension module produced by ``torch.utils.cpp_extension.load``.

    Raises:
        Exception: whatever the build or import raised (compiler missing,
            unsupported architecture, compile error, ...).
    """
    global _compiled_ext, _compiled_ext_error
    if _compiled_ext is not None:
        return _compiled_ext
    if _compiled_ext_error is not None:
        raise _compiled_ext_error

    import torch.utils.cpp_extension as cpp_ext

    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            cu_path = os.path.join(tmpdir, "nvfp4_blockscaled_gemm.cu")
            # The CUDA source contains non-ASCII characters in its comments,
            # so the encoding must not depend on the process locale.
            with open(cu_path, "w", encoding="utf-8") as f:
                f.write(NVFP4_BLOCKSCALED_GEMM_CUDA)

            ext = cpp_ext.load(
                name="nvfp4_blockscaled",
                sources=[cu_path],
                extra_cuda_cflags=[
                    "-gencode=arch=compute_100a,code=sm_100a",
                    "--expt-relaxed-constexpr",
                    "-DNVFP4_BLOCKSCALED_ENABLED=1",
                ],
                extra_cflags=["-O2"],
                verbose=bool(MEGA_MOE_DEBUG),
            )
    except Exception as exc:
        _compiled_ext_error = exc
        raise

    _compiled_ext = ext
    return ext
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Native NVFP4 GEMM API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def nvfp4_blockscaled_gemm(
    A_packed: torch.Tensor,  # (M, K//2) int8 — E2M1 packed, K-major
    A_scales: torch.Tensor,  # (M, K//16) float8_e4m3fn — UE4M3 block16 scales
    B_packed: torch.Tensor,  # (N, K//2) int8 — E2M1 packed, K-major
    B_scales: torch.Tensor,  # (N, K//16) float8_e4m3fn — UE4M3 block16 scales
) -> torch.Tensor:
    """NVFP4 block-scaled GEMM: C = A @ B^T.

    A is (M, K//2) int8 E2M1 packed with (M, K//16) UE4M3 scales.
    B is (N, K//2) int8 E2M1 packed with (N, K//16) UE4M3 scales.
    C is (M, N) float32.

    The runtime-compiled CUDA extension is tried first. If building or
    running it raises, the operands are dequantized to BF16 and multiplied
    with ``torch.matmul`` in float32 instead; the failure is only printed
    when ``MEGA_MOE_DEBUG`` is set.

    Raises:
        AssertionError: if an operand is not int8, the operands disagree on
            K, or the operands are not CUDA tensors.
    """
    M = A_packed.shape[0]
    K_half = A_packed.shape[1]
    K = K_half * 2
    N = B_packed.shape[0]

    assert A_packed.dtype == torch.int8, f"A must be int8, got {A_packed.dtype}"
    assert B_packed.dtype == torch.int8, f"B must be int8, got {B_packed.dtype}"
    # The kernel derives every stride from A's K; a mismatched B would be read
    # out of bounds rather than rejected.
    assert B_packed.shape[1] == K_half, (
        f"K mismatch: A has {K_half} packed columns, B has {B_packed.shape[1]}"
    )
    assert A_packed.is_cuda and B_packed.is_cuda, "Tensors must be on CUDA"

    def _as_uint8(scales: torch.Tensor) -> torch.Tensor:
        # The extension takes the raw UE4M3 bytes; reinterpret, do not convert.
        if scales.dtype == torch.float8_e4m3fn:
            return scales.view(torch.uint8)
        return scales

    # Try the compiled CUDA path
    try:
        ext = _get_compiled_extension()
        return ext.gemm_forward(
            A_packed, _as_uint8(A_scales), B_packed, _as_uint8(B_scales), M, N, K
        )
    except Exception as e:
        # Deliberate best-effort: any build/launch failure routes to the
        # pure-PyTorch fallback below.
        if MEGA_MOE_DEBUG:
            print(f"[nvfp4_gemm] Native kernel failed, using dequant fallback: {e}")

    # Fallback: dequantize and use torch.matmul
    from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16

    A_bf16 = unpack_e2m1_to_bf16(A_packed, A_scales)  # (M, K)
    B_bf16 = unpack_e2m1_to_bf16(B_packed, B_scales)  # (N, K)
    return torch.matmul(A_bf16.to(torch.float32), B_bf16.to(torch.float32).t())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MoE Grouped GEMM with native NVFP4 block-scaled MMA
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def grouped_gemm_nvfp4_native(
    x_packed: torch.Tensor,       # (num_tokens, K//2) int8 — E2M1 packed
    x_scales: torch.Tensor,       # (num_tokens, K//16) UE4M3 scales
    weights: torch.Tensor,        # (E, N, K//2) int8 — per-expert E2M1 weights
    weight_scales: torch.Tensor,  # (E, N, K//16) UE4M3 per-expert scales
    topk_ids: torch.Tensor,       # (num_tokens, NUM_TOPK) int32
    topk_weights: torch.Tensor,   # (num_tokens, NUM_TOPK) float32
) -> torch.Tensor:
    """Segmented grouped expert GEMM over NVFP4 operands.

    For every expert that received at least one token, gathers the tokens
    routed to it (across all top-k slots at once), runs one
    ``nvfp4_blockscaled_gemm`` call, scales each row by its routing weight
    and accumulates it into that token's output row.

    Args:
        x_packed: Packed E2M1 activations (num_tokens, K//2)
        x_scales: UE4M3 block16 scales (num_tokens, K//16)
        weights: Per-expert E2M1 weights (E, N, K//2)
        weight_scales: Per-expert UE4M3 scales (E, N, K//16)
        topk_ids: Expert assignments (num_tokens, NUM_TOPK); ids outside
            ``[0, E)`` are ignored
        topk_weights: Routing weights (num_tokens, NUM_TOPK)

    Returns:
        (num_tokens, N) bfloat16 output
    """
    num_tokens = x_packed.shape[0]
    E = weights.shape[0]
    N = weights.shape[1]
    device = x_packed.device

    output = torch.zeros(num_tokens, N, dtype=torch.float32, device=device)
    if num_tokens == 0 or topk_ids.numel() == 0:
        return output.to(torch.bfloat16)

    # One host sync to learn which experts are in use, instead of one
    # `.any()` sync per (expert, slot) pair. `unique` returns ids sorted
    # ascending, so contributions are accumulated in expert order.
    for e in torch.unique(topk_ids).tolist():
        if not 0 <= e < E:
            continue

        # All (token, slot) pairs routed to this expert.
        token_indices, slot_indices = (topk_ids == e).nonzero(as_tuple=True)

        # NVFP4 GEMM: (n, K) @ (N, K)^T → (n, N) float32
        result = nvfp4_blockscaled_gemm(
            x_packed[token_indices], x_scales[token_indices],
            weights[e], weight_scales[e],
        )

        # Weighted scatter-add. index_add_ also accumulates correctly if a
        # token lists the same expert in more than one slot.
        gate = topk_weights[token_indices, slot_indices].unsqueeze(-1)
        output.index_add_(0, token_indices, (result * gate).to(output.dtype))

    return output.to(torch.bfloat16)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience wrappers for uint32 packed scales
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def grouped_gemm_nvfp4_packed_sf(
    x_packed: torch.Tensor,     # (num_tokens, K//2) int8
    x_sf_packed: torch.Tensor,  # (num_tokens, sf_groups) uint32 packed UE4M3
    weights: torch.Tensor,      # (E, N, K//2) int8
    weight_sf: torch.Tensor,    # (E, N, sf_groups) uint32 packed UE4M3
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
) -> torch.Tensor:
    """Run ``grouped_gemm_nvfp4_native`` on scales that may be uint32-packed.

    A scale tensor whose dtype is ``torch.uint32`` is first expanded with
    ``unpack_ue4m3_u32``; a tensor of any other dtype is forwarded unchanged.
    """
    from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_ue4m3_u32

    def _unpacked(scale_factors: torch.Tensor) -> torch.Tensor:
        if scale_factors.dtype == torch.uint32:
            return unpack_ue4m3_u32(scale_factors)
        return scale_factors

    activation_scales = _unpacked(x_sf_packed)
    expert_scales = _unpacked(weight_sf)

    return grouped_gemm_nvfp4_native(
        x_packed, activation_scales,
        weights, expert_scales,
        topk_ids, topk_weights,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TileLang-based NVFP4 GEMM (using T.gemm with float8_e4m3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_tilelang_kernel_cache = {}
|
||||
|
||||
|
||||
def _make_tilelang_nvfp4_gemm(M, N, K_packed, block_M=128, block_N=128, block_K=64):
    """Build (or fetch from cache) a TileLang float8_e4m3 GEMM kernel.

    The kernel is memoised in ``_tilelang_kernel_cache`` under the key
    ``(M, N, K_packed, block_M, block_N, block_K)``, so each distinct problem
    shape / tiling is built once per process. ``tilelang`` is imported only
    on a cache miss.

    Per the original author: ``T.gemm()`` with float8_e4m3 lowers to
    tcgen05.mma.kind::f8f4 on Blackwell, and this path does NOT apply
    UE4M3 block scaling — it multiplies the raw operands without scales.
    For NVFP4 with block scaling, use nvfp4_blockscaled_gemm().
    (NOTE(review): the lowering claim cannot be checked from this file.)

    The TileLang path is kept for experimentation and benchmarking.

    Args:
        M, N: output dimensions of C.
        K_packed: size of the reduction dimension as seen by the kernel
            (A is declared (M, K_packed), B is declared (K_packed, N)).
        block_M, block_N, block_K: tile sizes.

    Returns:
        The ``tilelang.jit``-decorated kernel; ``out_idx=[2]`` marks the
        third parameter, C, as the output.
    """
    key = (M, N, K_packed, block_M, block_N, block_K)
    if key in _tilelang_kernel_cache:
        return _tilelang_kernel_cache[key]

    import tilelang
    import tilelang.language as T

    # NOTE(review): K is computed but never used below.
    K = K_packed * 2  # Unpacked element count

    @tilelang.jit(out_idx=[2])
    def fp4_gemm(
        A: T.Tensor((M, K_packed), "float8_e4m3"),
        B: T.Tensor((K_packed, N), "float8_e4m3"),
        C: T.Tensor((M, N), T.float32),
    ):
        # One 128-thread block per (block_M, block_N) output tile.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float8_e4m3")
            B_shared = T.alloc_shared((block_K, block_N), "float8_e4m3")
            C_local = T.alloc_fragment((block_M, block_N), T.float32)

            T.clear(C_local)

            # Walk the reduction dimension in block_K steps, 2-stage pipelined.
            for k in T.Pipelined(T.ceildiv(K_packed, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, num_elems_per_byte=2)

            T.copy(C_local, C[by * block_M, bx * block_N])

    _tilelang_kernel_cache[key] = fp4_gemm
    return fp4_gemm
|
||||
@@ -1,10 +1,10 @@
|
||||
"""
|
||||
NVFP4 Weight Transformation for TileLang mega_moe kernel.
|
||||
NVFP4 Weight Transformation for CUTLASS mega_moe kernel.
|
||||
|
||||
Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale)
|
||||
into the TMA-aligned format expected by the block-scaled GEMM kernel:
|
||||
- Packed FP4 weights (uint8, K-major)
|
||||
- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout)
|
||||
into the format expected by the CUTLASS block-scaled GEMM kernel:
|
||||
- Packed FP4 weights (int8, K-major)
|
||||
- UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles layout)
|
||||
|
||||
This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.
|
||||
|
||||
@@ -70,13 +70,13 @@ def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor:
|
||||
return interleaved.contiguous()
|
||||
|
||||
|
||||
def transform_nvfp4_weights_for_tilelang(
|
||||
def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1
|
||||
l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Transform NVFP4 weights for the TileLang block-scaled GEMM.
|
||||
"""Transform NVFP4 weights for the CUTLASS block-scaled GEMM.
|
||||
|
||||
Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights.
|
||||
|
||||
@@ -113,7 +113,3 @@ def transform_nvfp4_weights_for_tilelang(
|
||||
l2_weight_out = l2_weight.contiguous()
|
||||
|
||||
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
|
||||
|
||||
|
||||
# Alias for drop-in replacement
|
||||
transform_nvfp4_weights_for_mega_moe = transform_nvfp4_weights_for_tilelang
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
NVFP4 quantization utilities — E2M1 packing and UE4M3 scale handling.
|
||||
Core math layer for the NVFP4 mega_moe kernel rewrite.
|
||||
"""
|
||||
|
||||
# E2M1 magnitude lookup table (positive values only)
|
||||
# Index 0-7 maps to: 0, 0.5, 1, 1.5, 2, 3, 4, 6
|
||||
def e2m1_magnitude(index: Int) -> Float64:
    """Return the magnitude encoded by the low three bits of an E2M1 code.

    Codes 0..7 map to 0, 0.5, 1, 1.5, 2, 3, 4, 6. Any other index yields 0.0.
    """
    if index == 1:
        return 0.5
    elif index == 2:
        return 1.0
    elif index == 3:
        return 1.5
    elif index == 4:
        return 2.0
    elif index == 5:
        return 3.0
    elif index == 6:
        return 4.0
    elif index == 7:
        return 6.0
    # Index 0 and every out-of-range index.
    return 0.0
|
||||
|
||||
|
||||
def quantize_e2m1(value: Float64) -> UInt8:
    """Quantize a float64 to the nearest E2M1 code.

    Returns the 4-bit nibble: bit 3 is the sign, bits 0..2 index the
    magnitude table of `e2m1_magnitude`. When two magnitudes are equally
    close the lower index wins, because a candidate only replaces the
    current best on a strictly smaller error.
    """
    var negative = value < 0.0
    var magnitude = value
    if negative:
        magnitude = -value

    # Start from code 0 (magnitude 0.0), whose error is the magnitude itself.
    var nearest = 0
    var smallest_gap = magnitude

    for candidate in range(1, 8):
        gap = abs(magnitude - e2m1_magnitude(candidate))
        if gap < smallest_gap:
            smallest_gap = gap
            nearest = candidate

    var sign_bit = 0
    if negative:
        sign_bit = 1

    return (sign_bit << 3) | nearest
|
||||
|
||||
|
||||
def unpack_e2m1(packed: UInt8, idx: Int) -> Float64:
    """Decode one of the two E2M1 values stored in a byte.

    idx=0 selects the low nibble; any other idx selects the high nibble.
    Bit 3 of the nibble is the sign, bits 0..2 index `e2m1_magnitude`.
    """
    var selected: UInt8 = packed
    if idx != 0:
        selected = packed >> 4
    # Keep all four bits, sign included.
    selected = selected & 0x0F

    value = e2m1_magnitude(Int(selected & 0x07))
    if (selected >> 3) & 1:
        return -value
    return value
|
||||
|
||||
|
||||
def dequantize_nvfp4_weight(
    packed_weight: UInt8,
    block_scale: Float64,
    group_offset: Int,
) -> Float64:
    """Dequantize one NVFP4 weight element.

    Decodes the nibble selected by `group_offset` (0 = low, otherwise high)
    from `packed_weight` and multiplies it by `block_scale`. No separate
    global scale is applied here; per the original note it is expected to
    be folded into `block_scale` already.
    """
    return unpack_e2m1(packed_weight, group_offset) * block_scale
|
||||
|
||||
|
||||
def main() raises:
    """Smoke test: print round-trip results for the E2M1 helpers."""
    # Test E2M1 quantization round-trip: every value in the list is exactly
    # representable, so the printed decode should equal the input.
    print("E2M1 quantization test:")
    for val in [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]:
        packed = quantize_e2m1(val)
        unpacked = unpack_e2m1(packed, 0)
        print(" ", val, " -> E2M1 -> ", unpacked)

    # Test packed byte (two E2M1 values): hi goes in the high nibble, lo in
    # the low nibble.
    print("\nPacked byte test:")
    lo = 1.5
    hi = -3.0
    packed = (quantize_e2m1(hi) << 4) | quantize_e2m1(lo)
    print(" lo=", lo, " hi=", hi, " packed=", packed)
    print(" unpack lo=", unpack_e2m1(packed, 0), " unpack hi=", unpack_e2m1(packed, 1))

    # Test NVFP4 dequantization
    print("\nNVFP4 dequantization test:")
    # 0x36: low nibble 0x6 -> 4.0, high nibble 0x3 -> 1.5 (per e2m1_magnitude),
    # so with scale 0.5 the expected outputs are lo=2.0 and hi=0.75.
    packed_w = UInt8(0x36)
    scale = 0.5
    print(" packed=0x36, scale=0.5, lo=", dequantize_nvfp4_weight(packed_w, scale, 0))
    print(" packed=0x36, scale=0.5, hi=", dequantize_nvfp4_weight(packed_w, scale, 1))
|
||||
@@ -1,65 +0,0 @@
|
||||
"""
|
||||
NVFP4 Activation Quantization — BF16 → packed FP4 + UE4M3 block16 scales
|
||||
|
||||
Port of patches/staging_kernel.py to TileLang.
|
||||
|
||||
This quantizes BF16 hidden states to:
|
||||
- Packed FP4 (E2M1, 2 values per uint8 byte)
|
||||
- UE4M3 block16 scales (4 values per uint32 word, group_size=16)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import tilelang
|
||||
import tilelang.language as T
|
||||
|
||||
|
||||
@tilelang.jit
def fp4_quant_kernel(
    X,  # (M, K) bfloat16 input
    block_M=128,
    block_K=128,
    group_size=16,  # NVFP4 UE4M3 group_size
):
    """Quantize BF16 → packed FP4 (E2M1) + UE4M3 block16 scales (UNFINISHED).

    Output:
    - X_FP4: (M, K//2) uint8 packed E2M1
    - X_SF: (M, K//group_size//4) uint32 packed UE4M3 scales

    NOTE(review): this is a skeleton, not a working quantizer:
    - every written byte of X_FP4 is the constant placeholder 0;
    - X_SF is allocated and returned but never written;
    - `row`/`col` are derived from `bx` and the thread index only; `by` is
      not used, and 128 threads with block_K // 2 columns address at most
      two rows per block at the default tile sizes.
    """
    M, K = T.const("M, K")
    X: T.Tensor[[M, K], "bfloat16"]

    X_FP4 = T.empty((M, K // 2), "uint8")
    X_SF = T.empty((M, K // (group_size * 4)), "uint32")

    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=128) as (bx, by):
        # Each block quantizes a (block_M, block_K) tile
        tx = T.get_thread_binding()

        # One thread per output byte: block_K // 2 bytes per row.
        row = bx * block_M + tx // (block_K // 2)
        col = tx % (block_K // 2)

        if row < M and col < K // 2:
            # Load 2 BF16 values (currently unused by the placeholder below)
            val_lo = X[row, col * 2]
            val_hi = X[row, col * 2 + 1]

            # Quantize to E2M1 (4-bit each, pack 2 per byte)
            # ... (E2M1 quantization logic)
            packed = T.uint8(0)  # placeholder
            X_FP4[row, col] = packed

            # Compute block scale for group of 16 elements
            # ... (UE4M3 scale computation)

    return X_FP4, X_SF
|
||||
|
||||
|
||||
def fp4_quant_reference(x_bf16, group_size=16):
    """Reference FP4 quantization delegated to DeepGEMM.

    Calls ``deep_gemm.per_token_cast_to_fp4`` on ``x_bf16`` and returns its
    two results as a ``(packed_fp4, scale_factors)`` tuple. Intended as the
    ground truth when checking the TileLang kernel.

    ``group_size`` is accepted for signature parity but is not forwarded.
    """
    from deep_gemm import per_token_cast_to_fp4

    packed_fp4, scale_factors = per_token_cast_to_fp4(x_bf16)
    return packed_fp4, scale_factors
|
||||
@@ -1,147 +0,0 @@
|
||||
"""
|
||||
NVFP4 Weight Transformation for TileLang mega_moe kernel.
|
||||
|
||||
Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale)
|
||||
into the TMA-aligned format expected by the block-scaled GEMM kernel:
|
||||
- Packed FP4 weights (uint8, K-major)
|
||||
- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout)
|
||||
|
||||
This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
def fold_global_scale(
    weight_scale: torch.Tensor,        # (N, K//16) or (E, N, K//16) float8_e4m3fn
    weight_scale_2: torch.Tensor,      # scalar, (num_logical,) or (E, num_logical) float32
    logical_widths: list[int] = None,  # per-logical-weight row counts
) -> torch.Tensor:
    """Fold the FP32 global scale into the block scales: UE4M3 * FP32 → float32.

    The product is returned in float32; it is NOT converted back to UE4M3
    here, so any re-quantisation is left to the caller.

    Supported layouts:
      * ``weight_scale_2`` with a single element: applied to every block scale.
      * Un-batched: ``weight_scale`` is ``(N, K//16)`` and ``weight_scale_2``
        holds one value per logical weight; logical weight ``i`` covers
        ``logical_widths[i]`` consecutive rows.
      * Batched per expert: ``weight_scale`` is ``(E, N, K//16)`` and
        ``weight_scale_2`` is ``(E, num_logical)``; each expert is scaled by
        its own row of global scales.

    When several global scales are given but ``logical_widths`` is ``None``
    the rows cannot be attributed to logical weights, so the maximum global
    scale is used (per expert in the batched layout, overall otherwise).

    Args:
        weight_scale: Block scales; converted to float32 before multiplying.
        weight_scale_2: Global scale(s), see the layouts above.
        logical_widths: Row count of each logical weight, or ``None``.

    Returns:
        float32 tensor with the same shape as ``weight_scale``.

    Raises:
        ValueError: If ``weight_scale_2`` is empty, has fewer entries than
            ``logical_widths``, or the widths do not add up to the number of
            rows in ``weight_scale``.
    """
    sf_f32 = weight_scale.to(torch.float32)
    global_scale = weight_scale_2.to(torch.float32)

    if global_scale.numel() == 0:
        raise ValueError("weight_scale_2 must contain at least one global scale")

    if global_scale.numel() == 1:
        return sf_f32 * global_scale

    # A leading dimension shared by both tensors is the expert dimension:
    # every expert must be multiplied by its own global scale(s).
    per_expert = (
        sf_f32.dim() == 3
        and global_scale.dim() == 2
        and global_scale.shape[0] == sf_f32.shape[0]
    )

    if per_expert and (logical_widths is None or global_scale.shape[1] == 1):
        # One scale per expert (or no way to split rows): broadcast over (N, K//16).
        return sf_f32 * global_scale.max(dim=1).values.view(-1, 1, 1)

    if logical_widths is None:
        # Best-effort fallback when rows cannot be mapped to logical weights.
        return sf_f32 * global_scale.max()

    widths = list(logical_widths)
    if per_expert:
        scales = global_scale  # (E, num_logical)
    else:
        # One scale per logical weight, taken from the first element of each
        # leading-dimension entry; shaped (1, num_logical).
        scales = global_scale.reshape(global_scale.shape[0], -1)[:, :1].t()

    if scales.shape[1] < len(widths):
        raise ValueError(
            f"weight_scale_2 provides {scales.shape[1]} global scale(s) "
            f"but logical_widths has {len(widths)} entries"
        )
    n_rows = sf_f32.shape[-2]
    if sum(widths) != n_rows:
        raise ValueError(
            f"logical_widths sum to {sum(widths)} but weight_scale has {n_rows} rows"
        )

    # Repeat each logical weight's scale over the rows it owns.
    pieces = [scales[:, i:i + 1].expand(-1, w) for i, w in enumerate(widths)]
    row_scale = torch.cat(pieces, dim=1).unsqueeze(-1)  # (E or 1, N, 1)
    if not per_expert:
        row_scale = row_scale.squeeze(0)  # (N, 1), shared by any leading dims

    return sf_f32 * row_scale
|
||||
|
||||
|
||||
def pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
    """Pack 4 UE4M3 (float8_e4m3fn) values into one 32-bit word.

    Input:  (..., K//16) float8_e4m3fn
    Output: (..., K//64) torch.int32 holding the 32-bit pattern
            (4 values packed per word)

    Within each word, element ``4*j`` occupies bits 0-7, ``4*j + 1`` bits
    8-15, ``4*j + 2`` bits 16-23 and ``4*j + 3`` bits 24-31.

    Raises:
        ValueError: If the last dimension is not a multiple of 4.
    """
    # Reinterpret the raw bytes; no numeric conversion happens here.
    sf_u8 = sf.view(torch.uint8)  # (..., K//16) uint8
    last_dim = sf_u8.shape[-1]
    # Explicit raise instead of assert: under `python -O` a ragged last
    # dimension could otherwise broadcast silently into a wrong result.
    if last_dim % 4 != 0:
        raise ValueError(f"Last dim {last_dim} not divisible by 4")

    # Lane 0 is the low byte; OR the remaining lanes in at 8-bit offsets.
    packed = sf_u8[..., 0::4].to(torch.int32)
    for lane in range(1, 4):
        packed = packed | (sf_u8[..., lane::4].to(torch.int32) << (8 * lane))

    return packed.contiguous()
|
||||
|
||||
|
||||
def interleave_l1_weights(
    weight: torch.Tensor,  # (E, 2*INTER, K//2) int8, K-major
) -> torch.Tensor:
    """Interleave L1 (gate_up) weights for 2CTA UMMA.

    The gate and up projections are interleaved in groups of 8 rows to match
    the UMMA 2CTA schedule, giving the row order
    ``[gate_0..7, up_0..7, gate_8..15, up_8..15, ...]``.

    Assumes the first ``N // 2`` rows of each expert are the gate projection
    and the last ``N // 2`` rows the up projection (TODO confirm against the
    checkpoint layout).

    Args:
        weight: 3-D tensor ``(E, N, K_half)`` with ``N`` a multiple of 16.

    Returns:
        Contiguous tensor of the same shape with rows reordered as above.

    Raises:
        ValueError: If ``N`` is not divisible by 16.
    """
    E, N, K_half = weight.shape
    # N % 16 == 0 guarantees each half splits evenly into 8-row groups.
    if N % 16 != 0:
        raise ValueError(f"N={N} not divisible by 16")

    half = N // 2
    groups = half // 8
    w_gate = weight[:, :half].reshape(E, groups, 8, K_half)
    w_up = weight[:, half:].reshape(E, groups, 8, K_half)

    # Stacking on a new axis right after the group axis yields, per group,
    # the 8 gate rows followed by the 8 up rows once flattened back to N rows.
    interleaved = torch.stack([w_gate, w_up], dim=2).reshape(E, N, K_half)

    return interleaved.contiguous()
|
||||
|
||||
|
||||
def transform_nvfp4_weights_for_tilelang(
    weight: torch.Tensor,          # (E, N, K//2) int8, K-major packed FP4
    weight_scale: torch.Tensor,    # (E, N, K//16) float8_e4m3fn UE4M3
    weight_scale_2: torch.Tensor,  # (E, 1) or (E, 2) float32 global scale
    N: int,                        # output dimension (2*INTER for L1, HIDDEN for L2)
    K: int,                        # input dimension (HIDDEN for L1, INTER for L2)
    gran_k: int = 16,              # NVFP4 group_size
    is_l1: bool = False,           # True for gate_up_proj (needs interleaving)
    logical_widths: list[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Transform NVFP4 weights for the TileLang block-scaled GEMM.

    Steps: fold the global scale into the block scales, clamp and convert
    the result back to float8_e4m3fn, pack four scales per 32-bit word, and,
    for L1, apply the gate/up row interleave to the weights and the scales.

    ``N``, ``K`` and ``gran_k`` are accepted for interface compatibility but
    are not read by this function; shapes come from the tensors themselves.

    Returns:
        transformed_weight: (E, N, K//2) int8, K-major, contiguous (rows
            interleaved when ``is_l1``)
        transformed_sf: (E, N, K//64) packed UE4M3 with folded global scale,
            in the dtype produced by ``pack_ue4m3_to_uint32``; rows follow
            the same order as ``transformed_weight``
    """
    # Step 1: Fold global scale into block scales (float32 result).
    sf_folded = fold_global_scale(weight_scale, weight_scale_2, logical_widths)

    # Step 2: Clamp to the finite float8_e4m3fn range and convert back.
    sf_clamped = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)

    # Step 3: Pack four UE4M3 bytes into each 32-bit word.
    sf_packed = pack_ue4m3_to_uint32(sf_clamped)  # (E, N, K//64)

    # Step 4: Interleave L1 rows if needed. Block scales are per row, so they
    # must receive the same row permutation as the weights; otherwise row i
    # of the weight would be paired with the scales of a different row.
    if is_l1:
        transformed_weight = interleave_l1_weights(weight)
        sf_packed = interleave_l1_weights(sf_packed)
    else:
        transformed_weight = weight.contiguous()

    # Step 5: K-major layout is assumed to come from the checkpoint already:
    # (E, N, K//2) with stride (N*K//2, K//2, 1). Nothing to do here.

    return transformed_weight, sf_packed
|
||||
|
||||
|
||||
def main():
    """Smoke-test the weight transformation with dummy data.

    Runs the L1 transform on random CPU tensors and verifies the shapes and
    weight dtype of the outputs before reporting success.

    Raises:
        AssertionError: If an output has an unexpected shape or dtype.
    """
    E = 4  # small test
    N = 64
    K = 128

    device = "cpu"  # No GPU on sandbox
    weight = torch.randint(-128, 127, (E, N, K // 2), dtype=torch.int8, device=device)
    weight_scale = torch.randn(E, N, K // 16, device=device).abs().to(torch.float8_e4m3fn)
    weight_scale_2 = torch.ones(E, 2, dtype=torch.float32, device=device)

    transformed_w, transformed_sf = transform_nvfp4_weights_for_tilelang(
        weight, weight_scale, weight_scale_2, N, K, is_l1=True,
        logical_widths=[32, 32],
    )

    print(f"Weight: {transformed_w.shape} {transformed_w.dtype}")
    print(f"SF: {transformed_sf.shape} {transformed_sf.dtype}")
    print(f"Expected SF: (E={E}, N={N}, K//64={K//64})")

    # Actually check the results; explicit raises survive `python -O`.
    if tuple(transformed_w.shape) != (E, N, K // 2):
        raise AssertionError(f"Unexpected weight shape {tuple(transformed_w.shape)}")
    if transformed_w.dtype != torch.int8:
        raise AssertionError(f"Unexpected weight dtype {transformed_w.dtype}")
    if tuple(transformed_sf.shape) != (E, N, K // 64):
        raise AssertionError(f"Unexpected SF shape {tuple(transformed_sf.shape)}")

    print("Weight transformation test passed!")
|
||||
|
||||
|
||||
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user