diff --git a/src/layout.mojo b/src/layout.mojo deleted file mode 100644 index af284d20..00000000 --- a/src/layout.mojo +++ /dev/null @@ -1,32 +0,0 @@ -""" -NVFP4 weight transformation and SF layout utilities. - -Port of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe -""" - -from math import ceil_div - -fn fold_global_scale_into_block_scales( - weight_scale: Tensor[float8_e4m3fn], # (N, K//16) UE4M3 block scales - weight_scale_2: Tensor[float32], # (num_logical,) or scalar global scale - logical_widths: List[int], # per-logical-weight row counts -) -> Tensor[float32]: - """Fold global scale into block scales: UE4M3 * FP32 -> FP32""" - # Convert UE4M3 to float32, multiply by global scale - # For MergedColumnParallelLinear, expand per-logical global scale - ... - -fn pack_ue4m3_to_int32(sf: Tensor[float8_e4m3fn]) -> Tensor[int32]: - """Pack 4 UE4M3 values (4 bytes) into one int32 for DeepGEMM TMA""" - # View as uint8, pack 4 consecutive bytes into int32 - ... - -fn transform_sf_into_required_layout( - sf_mn: Tensor[int32], # MN-major packed SF - N: int, K: int, - recipe: Tuple[int, int], # (gran_mn, gran_k) - num_groups: int, -) -> Tensor[int32]: - """Transform SF into TMA-aligned UTCCP layout for DeepGEMM""" - # Call into DeepGEMM's C++ layout transform - ... diff --git a/src/mega_moe.mojo b/src/mega_moe.mojo deleted file mode 100644 index d9056f5d..00000000 --- a/src/mega_moe.mojo +++ /dev/null @@ -1,24 +0,0 @@ -""" -NVFP4 Mega MoE Kernel — Mojo Rewrite - -This is a from-scratch rewrite of the DeepGEMM fp8_nvfp4_mega_moe kernel. -The CUDA C++ version crashes on B200 with CUDA_ERROR_LAUNCH_FAILED in the -SM100_MMA_MXF4NVF4 instruction. This Mojo rewrite aims to: -1. Produce a correct, working NVFP4 mega_moe kernel -2. Be more maintainable than 1200+ lines of CUDA template metaprogramming -3. Leverage Mojo's GPU programming model for cleaner TMA/UMMA integration - -NVFP4 format: -- Weights: E2M1 packed (2 x 4-bit values per byte), int8 container -- Block scales: UE4M3 (float8_e4m3fn), group_size=16 -- Global scale: float32 scalar -- Activation: FP8 e4m3fn with UE8M0 per-token scales - -Key differences from MXFP4 (group_size=32, UE8M0): -- 2 SF K-columns per BLOCK_K (128/16/4=2) instead of 1 -- mxf4nvf4 UMMA instruction with scale_vec::4X -- Different weight transformation (gran_k=16 vs gran_k=32) -""" - -# TODO: Implement as Mojo GPU kernel -# Waiting for Mojo GPU programming docs / SDK setup diff --git a/src/nvfp4_blockscaled_gemm.py b/src/nvfp4_blockscaled_gemm.py deleted file mode 100644 index 135c65bc..00000000 --- a/src/nvfp4_blockscaled_gemm.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -NVFP4 Block-Scaled GEMM — TileLang -2CTA persistent kernel for Blackwell SM100 - -Adapted from tilelang/examples/blockscaled_gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py - -Key NVFP4 differences from MXFP8: -- sf_granularity_k=16 (UE4M3 block16) instead of 128 (UE8M0) -- 2 SF uint32 words per BLOCK_K instead of 1 (sf_load_period=1, not 4) -- sf_a_id/sf_b_id cycle through 0,1,2,3 within each SF word -- Must call tcgen05_gemm_blockscaled twice per K-block with different SF words -- in_dtype="float4_e2m1fnx2" (packed FP4) instead of "float8_e4m3fn" (FP8) -""" - -import torch -import tilelang -import tilelang.language as T -from tilelang.carver.arch import driver -from tilelang.profiler import do_bench - -# NVFP4 constants -NVFP4_SF_GRAN_K = 16 # UE4M3 group_size -NVFP4_SF_PACK = 4 # 4 UE4M3 values per uint32 -NVFP4_SF_K_PER_WORD = NVFP4_SF_GRAN_K * NVFP4_SF_PACK # 64 K-elements per uint32 word -# For BLOCK_K=128: 128/64 = 2 SF words per K-block -# Each word covers 4 UMMA sub-atoms (sf_id 0-3) - - -@tilelang.jit -def nvfp4_blockscaled_gemm_2cta_persistent( - A, # (M, K) packed FP4 - B, # (N, K) packed FP4 (K-major for NN) - SFA, # (sf_k_groups * M) uint32 packed UE4M3 - SFB, # (sf_k_groups * N) uint32 packed UE4M3 - block_M=128, - block_N=256, - block_K=128, - in_dtype="float4_e2m1fnx2", - out_dtype="bfloat16", - accum_dtype="float32", - num_stages=2, - sf_granularity_k=NVFP4_SF_GRAN_K, - use_tma_store=True, - store_block_N=64, -): - M, N, K = T.const("M, N, K") - - assert block_M == 128 - assert block_N == 256 - assert block_K == 128 - - half_N = block_N // 2 - k_iters = T.ceildiv(K, block_K) - - # NVFP4 SF layout: - # sf_granularity_k=16, pack_factor=4 → 64 K-elements per uint32 - # BLOCK_K=128 → 2 SF words per K-block - # sf_load_period=1 (load every K-block, unlike MXFP8 which loads every 4) - sf_words_per_k_block = block_K // (sf_granularity_k * NVFP4_SF_PACK) # 2 - sf_k_groups = T.ceildiv(K, sf_granularity_k * NVFP4_SF_PACK) # total SF groups across K - - A: T.Tensor[[M, K], in_dtype] - B: T.Tensor[[N, K], in_dtype] - SFA: T.Tensor[[sf_k_groups * M], T.uint32] - SFB: T.Tensor[[sf_k_groups * N], T.uint32] - C = T.empty((M, N), out_dtype) - - sm_num = driver.get_num_sms() - num_clusters = sm_num // 2 - m_blocks = T.ceildiv(M, block_M) - m_clusters = m_blocks // 2 - n_blocks = T.ceildiv(N, block_N) - waves = T.ceildiv(m_blocks * n_blocks, sm_num) - group_size = 16 - assert n_blocks % (2 * group_size) == 0 - - with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id): - cta_id = T.block_rank_in_cluster() - T.assume(cta_id < 2) - - # Shared memory — FP4 packed (K is halved because 2 values per byte) - A_shared = T.alloc_shared((num_stages, block_M, block_K // 2), "uint8") - B_shared = T.alloc_shared((num_stages, block_K // 2, half_N), "uint8") - - # Shared memory for SF — 2 uint32 words per K-block for NVFP4 - SFA_shared = T.alloc_shared((num_stages, block_M, sf_words_per_k_block), T.uint32) - SFB_shared = T.alloc_shared((num_stages, block_N, sf_words_per_k_block), T.uint32) - - # Tensor memory - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], T.uint32) - SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], T.uint32) - - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) - C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) - - loaded = T.alloc_barrier([32] * num_stages) - with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages) - consumed = T.alloc_cluster_barrier([1] * num_stages) - tmem_full = T.alloc_cluster_barrier([1]) - tmem_empty = T.alloc_cluster_barrier([128 * 2]) - - tx = T.get_thread_binding() - warp_idx = tx // 32 - - if warp_idx == 0: - # Warp 0: TMA load - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - for k in T.serial(k_iters): - phase = w * k_iters + k - stage = phase % num_stages - parity = (phase // num_stages) & 1 - - T.mbarrier_wait_parity(consumed[stage], parity ^ 1) - - # TMA load A (packed FP4) - T.tma_copy( - A[bx * block_M, k * block_K], - A_shared[stage, :, :], - barrier=loaded[stage], - ) - # TMA load B (packed FP4) - T.tma_copy( - B[k * block_K, by * block_N + cta_id * half_N], - B_shared[stage, :, :], - barrier=loaded[stage], - ) - # TMA load SFA — 2 words per K-block for NVFP4 - # Word 0 covers K[k*128 : k*128+64] - # Word 1 covers K[k*128+64 : k*128+128] - for sf_word in T.unroll(sf_words_per_k_block): - sf_group = k * sf_words_per_k_block + sf_word - T.tma_copy( - SFA[sf_group * M + bx * block_M : - sf_group * M + (bx + 1) * block_M], - SFA_shared[stage, :, sf_word], - barrier=loaded[stage], - ) - # TMA load SFB — 2 words per K-block - for sf_word in T.unroll(sf_words_per_k_block): - sf_group = k * sf_words_per_k_block + sf_word - T.tma_copy( - SFB[sf_group * N + by * block_N : - sf_group * N + (by + 1) * block_N], - SFB_shared[stage, :, sf_word], - barrier=loaded[stage], - ) - T.mbarrier_arrive(loaded[stage]) - - elif warp_idx == 1 and cta_id == 0: - # Warp 1: MMA issue + UTCCP - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - for k in T.serial(k_iters): - phase = w * k_iters + k - stage = phase % num_stages - parity = (phase // num_stages) & 1 - - T.mbarrier_wait_parity(with_sf_full[stage], parity) - - # NVFP4: 2 SF words per K-block → 2 sub-GEMM calls - # Each sub-GEMM covers 64 K-elements (BLOCK_K/2) - for sf_word in T.unroll(sf_words_per_k_block): - # UTCCP: copy SF word from shared to tensor memory - T.tcgen05_cp_warpx4( - SFA_shared[stage, :, sf_word], - SFA_tmem, - use_2cta=True, - ) - T.tcgen05_cp_warpx4( - SFB_shared[stage, :, sf_word], - SFB_tmem, - use_2cta=True, - ) - - # Block-scaled GEMM - # For NVFP4: sf_a_id selects which of 4 packed UE4M3 - # values in the uint32 word to use. - # With sf_granularity_k=16 and UMMA_K=64: - # 4 SF values per UMMA atom, sf_id=0 covers all 4 - # (the hardware auto-cycles through 0,1,2,3 internally) - T.tcgen05_gemm_blockscaled( - A_shared[stage, :, :], - B_shared[stage, :, :], - C_tmem, - SFA_tmem, - SFB_tmem, - transpose_B=True, - mbar=consumed[stage], - clear_accum=(k == 0 and sf_word == 0), - sf_a_id=0, - sf_b_id=0, - use_2cta=True, - ) - - T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) - - elif warp_idx == 2: - # Warp 2: SF transpose - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - for k in T.serial(k_iters): - phase = w * k_iters + k - stage = phase % num_stages - parity = (phase // num_stages) & 1 - - T.mbarrier_wait_parity(loaded[stage], parity) - # Transpose all SF words for this K-block - for sf_word in T.unroll(sf_words_per_k_block): - T.tcgen05_sf_warp_transpose(SFA_shared[stage, :, sf_word]) - T.tcgen05_sf_warp_transpose(SFB_shared[stage, :, sf_word]) - T.fence_proxy_async() - T.mbarrier_arrive(with_sf_full[stage], 0) - - # Epilogue: write C from tmem to global memory - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - T.mbarrier_wait_parity(tmem_full, w & 1) - T.copy(C_tmem, C_local) - T.copy(C_local, C_local_cast) - T.copy(C_local_cast, C_shared) - - if use_tma_store: - T.copy(C_shared, C[bx * block_M, by * block_N]) - else: - T.copy(C_shared, C[bx * block_M, by * block_N], disable_tma=True) - - T.mbarrier_arrive(tmem_empty, 0) - - return C diff --git a/src/nvfp4_mega_moe.py b/src/nvfp4_mega_moe.py deleted file mode 100644 index ee0e26f1..00000000 --- a/src/nvfp4_mega_moe.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -NVFP4 Mega MoE Kernel — Full MoE with expert parallelism. - -This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM. - -Architecture: -- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales) -- SiLU+Mul activation -- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales) -- NVLink cross-rank sync via symm buffer -- Expert parallel: each rank handles NUM_EXPERTS/8 experts - -The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN. -""" - -import torch -import tilelang -import tilelang.language as T -from tilelang.carver.arch import driver - -# DeepSeek-V4-Pro dimensions -HIDDEN = 7168 -INTERMEDIATE = 3072 -NUM_EXPERTS = 256 -NUM_RANKS = 8 -NUM_TOPK = 6 - -# Block sizes for the block-scaled GEMM -BLOCK_M = 192 # tokens per tile -BLOCK_K = 128 -BLOCK_N = 128 - -# NVFP4 scale parameters -SF_GRANULARITY_K = 16 # UE4M3 group_size -SF_PACK_FACTOR = 4 # 4 UE4M3 values per uint32 -SF_WORDS_PER_K_BLOCK = BLOCK_K // (SF_GRANULARITY_K * SF_PACK_FACTOR) # 2 - - -@tilelang.jit -def nvfp4_mega_moe_l1( - # Activation (packed FP4, from staging kernel) - X, # (num_tokens, K//2) int8 packed E2M1 - X_SF, # (num_tokens, sf_k_groups) uint32 packed UE4M3 - # L1 weights (pre-transformed for mega_moe) - L1_W, # (num_experts_per_rank, 2*INTERMEDIATE, K//2) int8 K-major - L1_SF, # (num_experts_per_rank, 2*INTERMEDIATE, sf_k_groups) uint32 TMA-aligned - # Routing - TopkIdx, # (num_tokens, NUM_TOPK) int32 - TopkW, # (num_tokens, NUM_TOPK) float32 - # Output - Y, # (num_tokens, 2*INTERMEDIATE) bfloat16 - num_experts_per_rank, -): - """ - L1 GEMM for mega_moe: gate_up_proj - X(FP4) @ L1_W(FP4)^T → Y(BF16) with UE4M3 block scaling - - This handles the gate_up_proj for all experts in the rank. - Each token is routed to NUM_TOPK experts, and the GEMM is computed - per expert group using persistent scheduling. - """ - num_tokens = T.const("num_tokens") - K = T.const("K") - INTER = T.const("INTERMEDIATE") - - k_iters = T.ceildiv(K, BLOCK_K) - sf_k_groups = T.ceildiv(K, SF_GRANULARITY_K * SF_PACK_FACTOR) - - with T.Kernel( - T.ceildiv(num_tokens, BLOCK_M), - T.ceildiv(2 * INTER, BLOCK_N), - threads=256, - cluster_dims=2, - ) as (bx, by): - cta_id = T.block_rank_in_cluster() - T.assume(cta_id < 2) - - # ... (persistent scheduling, expert routing, L1 GEMM with block scaling) - # This follows the same pattern as nvfp4_blockscaled_gemm_2cta_persistent - # but adds expert scheduling and topk weight scaling - - pass # Full implementation in progress - - -@tilelang.jit -def nvfp4_mega_moe_l2( - # L1 output (quantized to FP4 for L2 input) - X, # (num_tokens, INTER//2) int8 packed E2M1 - X_SF, # (num_tokens, sf_k_groups) uint32 packed UE4M3 - # L2 weights - L2_W, # (num_experts_per_rank, HIDDEN, INTER//2) int8 K-major - L2_SF, # (num_experts_per_rank, HIDDEN, sf_k_groups) uint32 TMA-aligned - # Routing - TopkIdx, # (num_tokens, NUM_TOPK) int32 - TopkW, # (num_tokens, NUM_TOPK) float32 - # Output - Y, # (num_tokens, HIDDEN) bfloat16 - num_experts_per_rank, -): - """ - L2 GEMM for mega_moe: down_proj - X(FP4) @ L2_W(FP4)^T → Y(BF16) with UE4M3 block scaling - - After SiLU+Mul on the L1 output, the result is quantized to FP4 - and fed into L2. - """ - pass # Symmetric to L1 - - -def nvfp4_mega_moe_full( - hidden_states, # (num_tokens, HIDDEN) bfloat16 - topk_weights, # (num_tokens, NUM_TOPK) float32 - topk_ids, # (num_tokens, NUM_TOPK) int32 - l1_weights, # L1 weights (transformed for mega_moe) - l1_scales, # L1 UE4M3 scales (transformed) - l2_weights, # L2 weights (transformed for mega_moe) - l2_scales, # L2 UE4M3 scales (transformed) - symm_buffer, # NVLink symm buffer for cross-rank sync -): - """ - Full mega_moe forward pass: - 1. Stage: quantize BF16 hidden_states → FP4 + UE4M3 scales - 2. L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with block scaling) - 3. SiLU + Mul (activation) - 4. Quantize L1 output → FP4 + UE4M3 scales - 5. L2 GEMM: down_proj (FP4 x FP4 → BF16 with block scaling) - 6. NVLink sync + reduce across ranks - """ - # Step 1: Stage activation (BF16 → FP4 quantization) - # This is the staging kernel from patches/staging_kernel.py - x_fp4, x_sf = stage_activation(hidden_states) - - # Step 2: L1 GEMM - l1_output = nvfp4_mega_moe_l1( - x_fp4, x_sf, l1_weights, l1_scales, topk_ids, topk_weights) - - # Step 3: SiLU + Mul - gate, up = l1_output.chunk(2, dim=-1) - activated = torch.nn.functional.silu(gate) * up - - # Step 4: Quantize L1 output → FP4 - l1_fp4, l1_sf = stage_activation(activated) - - # Step 5: L2 GEMM - l2_output = nvfp4_mega_moe_l2( - l1_fp4, l1_sf, l2_weights, l2_scales, topk_ids, topk_weights) - - # Step 6: NVLink reduce - output = nvlink_reduce(l2_output, topk_weights, symm_buffer) - - return output - - -def stage_activation(x_bf16): - """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. - - This replaces the Triton staging kernel from patches/staging_kernel.py. - """ - # E2M1 quantization with UE4M3 block scaling - # For now, use PyTorch reference implementation - # TODO: Write as TileLang kernel for full pipeline integration - from deep_gemm import per_token_cast_to_fp4 - return per_token_cast_to_fp4(x_bf16) diff --git a/src/nvfp4_megamoe_kernel/__init__.py b/src/nvfp4_megamoe_kernel/__init__.py index c592bdb5..05cdc12d 100644 --- a/src/nvfp4_megamoe_kernel/__init__.py +++ b/src/nvfp4_megamoe_kernel/__init__.py @@ -1,4 +1,4 @@ -"""NVFP4 Mega MoE Kernel — TileLang implementation for DeepSeek-V4-Pro on Blackwell.""" +"""NVFP4 Mega MoE Kernel — CUTLASS implementation for DeepSeek-V4-Pro on Blackwell.""" from nvfp4_megamoe_kernel.nvfp4_mega_moe import ( nvfp4_mega_moe_full, @@ -7,7 +7,7 @@ from nvfp4_megamoe_kernel.nvfp4_mega_moe import ( stage_activation, ) from nvfp4_megamoe_kernel.weight_transform import ( - transform_nvfp4_weights_for_tilelang as transform_nvfp4_weights_for_mega_moe, + transform_nvfp4_weights_for_mega_moe, ) from nvfp4_megamoe_kernel.symm_buffer import ( SymmBuffer, diff --git a/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 71a6b920..00000000 Binary files a/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc and /dev/null differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc deleted file mode 100644 index f38c12b5..00000000 Binary files a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc and /dev/null differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc deleted file mode 100644 index 0ad6c213..00000000 Binary files a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc and /dev/null differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc deleted file mode 100644 index afea703e..00000000 Binary files a/src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc and /dev/null differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc deleted file mode 100644 index bc6cdc77..00000000 Binary files a/src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc and /dev/null differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc deleted file mode 100644 index 4ea5d0e8..00000000 Binary files a/src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc and /dev/null differ diff --git a/src/nvfp4_megamoe_kernel/nvfp4_dequant.py b/src/nvfp4_megamoe_kernel/nvfp4_dequant.py deleted file mode 100644 index d19d33ad..00000000 --- a/src/nvfp4_megamoe_kernel/nvfp4_dequant.py +++ /dev/null @@ -1,71 +0,0 @@ -""" -NVFP4 dequantization utilities. - -Converts packed E2M1 (int8) + UE4M3 block16 scales to BF16. -""" - -import torch - - -def unpack_ue4m3_u32(packed: torch.Tensor) -> torch.Tensor: - """Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn. - - Args: - packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element - - Returns: - (..., sf_k_groups * 4) float8_e4m3fn - """ - u32 = packed.to(torch.int32) - b0 = (u32 & 0xFF).to(torch.uint8) - b1 = ((u32 >> 8) & 0xFF).to(torch.uint8) - b2 = ((u32 >> 16) & 0xFF).to(torch.uint8) - b3 = ((u32 >> 24) & 0xFF).to(torch.uint8) - interleaved = torch.stack([b0, b1, b2, b3], dim=-1) - return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn) - - -def unpack_e2m1_to_bf16( - packed: torch.Tensor, # (..., K//2) int8 — two E2M1 values per byte - scales: torch.Tensor, # (..., K//16) float8_e4m3fn — UE4M3 block16 scales -) -> torch.Tensor: - """Dequantize packed E2M1 with UE4M3 block16 scales to BF16. - - E2M1 format: sign(1) exponent(2) mantissa(1), bias=2 - Each int8 byte contains 2 E2M1 values: low nibble=element 0, high nibble=element 1. - UE4M3 block scales: one float8_e4m3fn scale per group of 16 consecutive elements. - - Args: - packed: (..., K//2) int8 packed E2M1 - scales: (..., K//16) float8_e4m3fn UE4M3 block16 scales - - Returns: - (..., K) bfloat16 - """ - u8 = packed.view(torch.uint8) - lo = (u8 & 0x0F).to(torch.int32) # lower nibble - hi = (u8 >> 4).to(torch.int32) # upper nibble - - # Interleave: (..., K//2, 2) → (..., K) - unpacked = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1) - - # E2M1 → float32 - sign = (unpacked >> 3).to(torch.float32) * -2.0 + 1.0 - exp_field = (unpacked >> 1) & 0x3 - mant = (unpacked & 0x1).to(torch.float32) - - # E2M1 value = sign * 2^(exp - 2) * (1 + mant * 0.5) - val = sign * (2.0 ** (exp_field.to(torch.float32) - 2.0)) * (1.0 + mant * 0.5) - - # Zero: exp=0 and mant=0 - zero_mask = (exp_field == 0) & ((unpacked & 1) == 0) - val = val * (~zero_mask).to(torch.float32) - - # Apply UE4M3 block16 scales - sf_f32 = scales.to(torch.float32) - sf_expanded = sf_f32.repeat_interleave(16, dim=-1) - - K = unpacked.shape[-1] - sf_expanded = sf_expanded[..., :K] - - return (val * sf_expanded).to(torch.bfloat16) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 3990d409..881ac113 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -26,12 +26,24 @@ avoiding the costly dequantization step. import os import torch -from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 -from nvfp4_megamoe_kernel.tilelang_nvfp4_gemm import ( - nvfp4_blockscaled_gemm, - grouped_gemm_nvfp4_native, - grouped_gemm_nvfp4_packed_sf, -) +def unpack_ue4m3_u32(x_u32): + """Unpack uint32 packed UE4M3 scales to float8_e4m3fn. + + Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32]. + """ + x_u32 = x_u32.contiguous() + M, N = x_u32.shape + out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device) + # Vectorized unpack: extract 4 bytes from each uint32 + b0 = (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + b1 = ((x_u32 >> 8) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + b2 = ((x_u32 >> 16) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + b3 = ((x_u32 >> 24) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + out[:, 0::4] = b0 + out[:, 1::4] = b1 + out[:, 2::4] = b2 + out[:, 3::4] = b3 + return out # CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell) # Primary path: uses CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled @@ -97,21 +109,11 @@ def nvfp4_mega_moe_l1( x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales - # Use CUTLASS native block-scaled GEMM if available - if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE: - output = cutlass_grouped_nvfp4_gemm( - x_fp4, x_sf_fp8, - l1_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - else: - # Fallback to TileLang path - output = grouped_gemm_nvfp4_native( - x_fp4, x_sf_fp8, - l1_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - + output = cutlass_grouped_nvfp4_gemm( + x_fp4, x_sf_fp8, + l1_weights, w_sf_fp8, + topk_ids, topk_weights, + ) return output # (num_tokens, 6144) bfloat16 @@ -141,21 +143,11 @@ def nvfp4_mega_moe_l2( x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales - # Use CUTLASS native block-scaled GEMM if available - if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE: - output = cutlass_grouped_nvfp4_gemm( - x_fp4, x_sf_fp8, - l2_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - else: - # Fallback to TileLang path - output = grouped_gemm_nvfp4_native( - x_fp4, x_sf_fp8, - l2_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - + output = cutlass_grouped_nvfp4_gemm( + x_fp4, x_sf_fp8, + l2_weights, w_sf_fp8, + topk_ids, topk_weights, + ) return output # (num_tokens, 7168) bfloat16 diff --git a/src/nvfp4_megamoe_kernel/tilelang_kernels.py b/src/nvfp4_megamoe_kernel/tilelang_kernels.py deleted file mode 100644 index 1ab67e19..00000000 --- a/src/nvfp4_megamoe_kernel/tilelang_kernels.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization. - -This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer: -- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales -- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales - -Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang. -This is correct and functional. Once TileLang adds native tcgen05.mma -kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to -native FP4 block-scaled MMA for maximum throughput. - -The per-expert GEMM uses a "segmented" approach: sort tokens by expert, -batched GEMM per expert using TileLang-compiled BF16 kernels. -""" - -import torch -import tilelang -import tilelang.language as T - -from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 - -# --------------------------------------------------------------------------- -# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05) -# --------------------------------------------------------------------------- - -_kernel_cache = {} - - -def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3): - """Build and cache a TileLang BF16 GEMM kernel for the given dimensions.""" - key = (M, N, K, block_M, block_N, block_K, num_stages) - if key in _kernel_cache: - return _kernel_cache[key] - - @tilelang.jit(out_idx=[2]) - def bf16_gemm( - A: T.Tensor((M, K), T.bfloat16), - B: T.Tensor((K, N), T.bfloat16), - C: T.Tensor((M, N), T.bfloat16), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) - B_shared = T.alloc_shared((block_K, block_N), T.bfloat16) - C_local = T.alloc_fragment((block_M, block_N), T.float32) - - T.clear(C_local) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - _kernel_cache[key] = bf16_gemm - return bf16_gemm - - -# --------------------------------------------------------------------------- -# Grouped expert GEMM with FP4 dequantization -# --------------------------------------------------------------------------- - -def grouped_gemm_fp4( - x_bf16: torch.Tensor, # (total_tokens, K_dim) bfloat16 - weights_fp4: torch.Tensor, # (E, N, K//2) int8 packed E2M1 - scales_ue4m3: torch.Tensor, # (E, N, K//16) float8_e4m3fn - topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32 - topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32 -) -> torch.Tensor: - """Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM. - - Strategy: - 1. Sort tokens by expert assignment - 2. For each expert, dequantize its weight to BF16 (cached) - 3. Run batched BF16 GEMM using TileLang-compiled kernels - 4. Scatter results back with routing weights - """ - num_tokens, K_dim = x_bf16.shape - E, N, K_half = weights_fp4.shape - K = K_half * 2 - assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}" - top_k = topk_ids.shape[1] - device = x_bf16.device - - output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=device) - - # Pre-compute expert weight dequantization (cache for repeated use) - # For 32 experts, this is manageable - w_bf16_cache = {} - for e in range(E): - w_bf16_cache[e] = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e]) # (N, K) - - # Process per expert - for e in range(E): - # Find all (token, k_idx) pairs for this expert - mask = (topk_ids == e) # (num_tokens, top_k) - if not mask.any(): - continue - - w_bf16 = w_bf16_cache[e] # (N, K) - - # Collect tokens for this expert across all top-k slots - for k_idx in range(top_k): - token_mask = mask[:, k_idx] - if not token_mask.any(): - continue - token_indices = token_mask.nonzero(as_tuple=True)[0] - - # Gather activations - x_sub = x_bf16[token_indices] # (n, K) - - # BF16 GEMM: (n, K) @ (N, K).T → (n, N) - result = torch.nn.functional.linear(x_sub, w_bf16) - - # Weighted scatter-add - weights = topk_weights[token_indices, k_idx].unsqueeze(-1) - output[token_indices] += result * weights - - return output - - -# --------------------------------------------------------------------------- -# Convenience: grouped GEMM with uint32 packed scales -# --------------------------------------------------------------------------- - -def grouped_gemm_fp4_packed_sf( - x_bf16: torch.Tensor, - weights_fp4: torch.Tensor, # (E, N, K//2) int8 - scales_packed: torch.Tensor, # (E, N, sf_k_groups) uint32 packed UE4M3 - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, -) -> torch.Tensor: - """Same as grouped_gemm_fp4 but unpacks uint32 packed UE4M3 scales first.""" - scales_fp8 = unpack_ue4m3_u32(scales_packed) - return grouped_gemm_fp4(x_bf16, weights_fp4, scales_fp8, topk_ids, topk_weights) diff --git a/src/nvfp4_megamoe_kernel/tilelang_nvfp4_gemm.py b/src/nvfp4_megamoe_kernel/tilelang_nvfp4_gemm.py deleted file mode 100644 index 0a145d60..00000000 --- a/src/nvfp4_megamoe_kernel/tilelang_nvfp4_gemm.py +++ /dev/null @@ -1,599 +0,0 @@ -""" -Native NVFP4 Block-Scaled GEMM using tcgen05.mma kind::mxf8f6f4.block_scale. - -This module provides the native NVFP4 tensor core GEMM for Blackwell (SM100). -It uses the mxf8f6f4.block_scale PTX instruction which natively performs -E2M1 × E2M1 multiplication with UE4M3 block-16 scaling in tensor cores. - -Architecture: - - A: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA - - B: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA - - SFA: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld - - SFB: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld - - C: accumulated in TMEM, stored to global memory - -The key PTX instruction: - tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [tmem_c], - desc_a, desc_b, idescE, [tsfa_addr], [tsfb_addr], pred; - -This is the native hardware path for NVFP4 block-scaled MMA on Blackwell. -No dequantization is performed — the hardware does E2M1×E2M1 with UE4M3 -block scaling natively in the tensor cores. - -Implementation Strategy: - We compile the CUDA kernel at runtime using torch.utils.cpp_extension.load. - The CUDA kernel uses inline PTX for the mxf8f6f4.block_scale instruction - and TMA for efficient global→SMEM transfers. - - For the MoE (Mixture of Experts) use case, we support grouped GEMM where - each expert has its own weight matrix and scale factors, and tokens are - routed to specific experts via top-k indices. -""" - -import os -import time -import torch -import tempfile -import subprocess -import hashlib -from typing import Optional - -# DeepSeek-V4-Pro dimensions -HIDDEN = 7168 -INTERMEDIATE = 3072 -NUM_EXPERTS = 256 -NUM_RANKS = 8 -NUM_TOPK = 6 - -# Block sizes for the GEMM tiling -BLOCK_M = 128 -BLOCK_N = 128 -BLOCK_K = 64 # For f8f6f4, atom_k=32 elements = 16 bytes packed; we use 64 for double buffering - -MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0")) - -# --------------------------------------------------------------------------- -# CUDA kernel source -# --------------------------------------------------------------------------- - -NVFP4_BLOCKSCALED_GEMM_CUDA = r""" -#include -#include -#include -#include -#include - -// PTX instructions for Blackwell tcgen05 MMA with block scaling -// We use inline PTX for mxf8f6f4.block_scale which is not yet in cuda.h - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 -#define TCGEN05_BLOCKSCALED_ENABLED 1 -#else -#define TCGEN05_BLOCKSCALED_ENABLED 0 -#endif - -// Helper: initialize tcgen05 SMEM descriptor (64-bit) -// Matches the layout from TileLang's common.h initialize_tcgen05_descriptor -__device__ __forceinline__ -uint64_t make_tcgen05_descriptor( - uint32_t smem_addr, // shared memory base address - uint32_t leading_bytes, // bytes between consecutive rows/cols in the leading dim - uint32_t stride_bytes, // bytes between tiles in the stride dim - uint32_t base_offset, // offset within the 256B swizzle block (0-7) - uint32_t leading_abs, // 1 if leading offset is absolute, 0 if relative - uint32_t swizzle_mode // 0=none, 1=32B, 2=64B, 3=128B, 4=128B_base32B -) { - uint32_t lo = (smem_addr >> 4) | (leading_bytes << 16); - uint32_t hi = stride_bytes | (1u << 14) // version = 1 - | ((base_offset & 0x7) << 17) - | (leading_abs << 20) - | ((swizzle_mode & 0x7) << 29); - return (static_cast(hi) << 32) | lo; -} - -// ============================================================================ -// Native NVFP4 Block-Scaled GEMM Kernel -// ============================================================================ -// -// Computes C = A @ B where: -// A is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFA) -// B is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFB) -// C is float32 (or bfloat16) -// -// Uses tcgen05.mma.kind::mxf8f6f4.block_scale PTX instruction. -// This performs E2M1 × E2M1 with UE4M3 block scaling natively in tensor cores. -// -// For MoE: each CTA handles one (expert, m_tile, n_tile) work item. -// ============================================================================ - -template -__global__ void __launch_bounds__(128) -nvfp4_blockscaled_gemm_kernel( - // A: (M, K_packed) int8 — K_packed = K/2 (2 E2M1 per byte) - const int8_t* __restrict__ A, - // SFA: (M, K_sf) float8_e4m3fn — K_sf = K/16 (one scale per 16 elements) - const __nv_fp8_e4m3* __restrict__ SFA, - // B: (N, K_packed) int8 — K-major layout - const int8_t* __restrict__ B, - // SFB: (N, K_sf) float8_e4m3fn - const __nv_fp8_e4m3* __restrict__ SFB, - // C: (M, N) float32 output - float* __restrict__ C, - // Dimensions - int M, int N, int K, - // Strides - int64_t stride_a_m, int64_t stride_a_k, // A row/col strides in elements - int64_t stride_sfa_m, int64_t stride_sfa_k, // SFA strides - int64_t stride_b_n, int64_t stride_b_k, // B row/col strides in elements - int64_t stride_sfb_n, int64_t stride_sfb_k, // SFB strides - int64_t stride_c_m, int64_t stride_c_n // C strides -) { - // For SM100+, we would use tcgen05.mma.kind::mxf8f6f4.block_scale - // However, the full implementation requires: - // 1. TMA descriptors for global→SMEM async copies - // 2. tcgen05.ld for SMEM→TMEM scale factor transfer - // 3. TMEM allocation for accumulators and scale factors - // 4. tcgen05.mma.kind::mxf8f6f4.block_scale PTX - // 5. TMEM→global store for results - // - // The TMA descriptor setup requires CUDA runtime APIs that are only - // available in CUDA 13.0+ driver. For now, we implement a fallback - // that does the dequantize+GEMM on tensor cores with BF16 MMA, - // and document the native path for when TMA descriptor APIs are stable. - -#if TCGEN05_BLOCKSCALED_ENABLED && 0 // Disabled until TMA APIs are stable - // Native f8f6f4 block-scaled MMA path - // This code path will be enabled once CUDA 13.0 TMA descriptor APIs - // are available in the PyTorch CUDA extension build system. - - // ... (TMA + tcgen05.mma.kind::mxf8f6f4.block_scale PTX) ... - -#else - // Fallback: Dequantize E2M1+UE4M3 → BF16, then BF16 GEMM - // This uses tcgen05.mma.kind::f16 which is native BF16 tensor core MMA. - // The dequantization is done per-tile in shared memory. - - const int tid = threadIdx.x; - const int bx = blockIdx.x; - const int by = blockIdx.y; - - const int m_start = by * BLOCK_M; - const int n_start = bx * BLOCK_N; - - if (m_start >= M || n_start >= N) return; - - const int K_packed = K / 2; // E2M1 packed: 2 per byte - const int K_sf = K / 16; // UE4M3 block16: 1 scale per 16 elements - const int BLOCK_K_packed = BLOCK_K_ELEMS / 2; - const int BLOCK_K_sf = BLOCK_K_ELEMS / 16; - - // Shared memory for A (dequantized to BF16), B (dequantized to BF16) - extern __shared__ char smem[]; - __nv_bfloat16* sA = reinterpret_cast<__nv_bfloat16*>(smem); - __nv_bfloat16* sB = reinterpret_cast<__nv_bfloat16*>(smem + BLOCK_M * BLOCK_K_ELEMS * sizeof(__nv_bfloat16)); - - // Register fragment for accumulator - float accum[BLOCK_M * BLOCK_N / 128] = {0}; // Per-thread accumulator - - // For simplicity, use wmma-style accumulation - // Each thread in the CTA cooperates on the tile - const int m_valid = min(BLOCK_M, M - m_start); - const int n_valid = min(BLOCK_N, N - n_start); - - // Process K in tiles - for (int k_start = 0; k_start < K; k_start += BLOCK_K_ELEMS) { - const int k_valid = min(BLOCK_K_ELEMS, K - k_start); - const int k_packed_valid = k_valid / 2; - const int k_sf_valid = k_valid / 16; - - // Load and dequantize A tile: (BLOCK_M, BLOCK_K) BF16 - for (int idx = tid; idx < BLOCK_M * BLOCK_K_ELEMS; idx += 128) { - int m_local = idx / BLOCK_K_ELEMS; - int k_local = idx % BLOCK_K_ELEMS; - int m_global = m_start + m_local; - int k_global = k_start + k_local; - - if (m_global < M && k_global < K) { - // Load E2M1 value - int8_t packed = A[m_global * stride_a_k + k_global / 2]; - int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF; - - // E2M1 to float: sign * 2^(exp-2) * (1 + mant*0.5) - float sign = (e2m1_val >> 3) ? -1.0f : 1.0f; - int exp_field = (e2m1_val >> 1) & 0x3; - float mant = (e2m1_val & 1) * 0.5f; - float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant); - if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f; - - // Apply UE4M3 block scale - int sf_idx = k_global / 16; - __nv_fp8_e4m3 sf = SFA[m_global * stride_sfa_k + sf_idx]; - float sf_val = __nv_fp8_e4m3_to_float(sf); - - sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val); - } else { - sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f); - } - } - - // Load and dequantize B tile: (BLOCK_N, BLOCK_K) BF16 (K-major B) - for (int idx = tid; idx < BLOCK_N * BLOCK_K_ELEMS; idx += 128) { - int n_local = idx / BLOCK_K_ELEMS; - int k_local = idx % BLOCK_K_ELEMS; - int n_global = n_start + n_local; - int k_global = k_start + k_local; - - if (n_global < N && k_global < K) { - int8_t packed = B[n_global * stride_b_k + k_global / 2]; - int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF; - - float sign = (e2m1_val >> 3) ? -1.0f : 1.0f; - int exp_field = (e2m1_val >> 1) & 0x3; - float mant = (e2m1_val & 1) * 0.5f; - float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant); - if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f; - - int sf_idx = k_global / 16; - __nv_fp8_e4m3 sf = SFB[n_global * stride_sfb_k + sf_idx]; - float sf_val = __nv_fp8_e4m3_to_float(sf); - - sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val); - } else { - sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f); - } - } - - __syncthreads(); - - // BF16 GEMM accumulation: C += A @ B^T - // Simple per-thread row-major accumulation (not using tensor cores in this fallback) - // In the native path, tcgen05.mma handles this natively - for (int m_local = 0; m_local < m_valid; m_local++) { - for (int n_local = tid; n_local < n_valid; n_local += 128) { - float sum = 0.0f; - for (int k_local = 0; k_local < k_valid; k_local++) { - sum += __bfloat162float(sA[m_local * BLOCK_K_ELEMS + k_local]) - * __bfloat162float(sB[n_local * BLOCK_K_ELEMS + k_local]); - } - // Atomic add to output - int m_global = m_start + m_local; - int n_global = n_start + n_local; - atomicAdd(&C[m_global * stride_c_n + n_global], sum); - } - } - - __syncthreads(); - } -#endif -} - - -// ============================================================================ -// Native NVFP4 Block-Scaled GEMM — Full Pipeline (SM100 native path) -// ============================================================================ -// -// This kernel uses the full native pipeline: -// 1. TMA async copy: E2M1 data → SMEM -// 2. TMA async copy: UE4M3 scales → SMEM -// 3. tcgen05.ld: SMEM scales → TMEM -// 4. tcgen05.mma.kind::mxf8f6f4.block_scale: native block-scaled MMA -// 5. TMEM store: results → global memory -// -// Requires CUDA 13.0+ and SM100 (Blackwell) hardware. -// ============================================================================ - -template -__global__ void __launch_bounds__(128, 1) -nvfp4_blockscaled_gemm_native_kernel( - const int8_t* __restrict__ A_packed, // (M, K/2) E2M1 packed - const uint8_t* __restrict__ SFA_packed, // (M, K/16) UE4M3 scales as uint8 - const int8_t* __restrict__ B_packed, // (N, K/2) E2M1 packed - const uint8_t* __restrict__ SFB_packed, // (N, K/16) UE4M3 scales as uint8 - float* __restrict__ C_out, // (M, N) float32 output - int M, int N, int K_total, - int64_t stride_a, int64_t stride_sfa, - int64_t stride_b, int64_t stride_sfb, - int64_t stride_c -) { - // The native mxf8f6f4.block_scale path - // This requires: - // - TMA tensor map creation (cuTensorMapEncodeTiled) for A, B, SFA, SFB - // - Shared memory with proper swizzle layout for tcgen05 descriptors - // - TMEM allocation for accumulators (C) and scale factors (SFA, SFB) - // - tcgen05.ld for scale factor SMEM→TMEM transfer - // - tcgen05.mma.kind::mxf8f6f4.block_scale for native MMA - // - tcgen05.st for TMEM→global result store - // - // The full implementation requires CUDA 13.0 driver support for: - // - cuTensorMapEncodeTiled with sub-byte types - // - TMEM allocation/deallocation intrinsics - // - tcgen05.ld/st intrinsics - // - // For now, we delegate to the fallback kernel. - // The native path will be enabled in a follow-up when the build - // system supports CUDA 13.0 headers. - - // This should never be called — we use the fallback path - assert(0 && "Native path not yet compiled — use fallback"); -} - - -// ============================================================================ -// PyTorch bindings -// ============================================================================ - -torch::Tensor nvfp4_blockscaled_gemm_forward( - torch::Tensor A_packed, // (M, K/2) int8 — E2M1 packed - torch::Tensor SFA, // (M, K/16) float8_e4m3fn or uint8 — UE4M3 block16 scales - torch::Tensor B_packed, // (N, K/2) int8 — E2M1 packed - torch::Tensor SFB, // (N, K/16) float8_e4m3fn or uint8 — UE4M3 block16 scales - int64_t M, int64_t N, int64_t K -) { - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(A_packed.device()); - auto C = torch::zeros({M, N}, options); - - const int BLOCK_M = 128; - const int BLOCK_N = 128; - const int BLOCK_K = 128; - - dim3 grid((N + BLOCK_N - 1) / BLOCK_N, (M + BLOCK_M - 1) / BLOCK_M); - dim3 block(128); - - // Calculate shared memory size: A + B in BF16 - int smem_size = (BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) * sizeof(__nv_bfloat16); - - auto stream = c10::cuda::getCurrentCUDAStream(); - - nvfp4_blockscaled_gemm_kernel<<>>( - A_packed.data_ptr(), - reinterpret_cast(SFA.data_ptr()), - B_packed.data_ptr(), - reinterpret_cast(SFB.data_ptr()), - C.data_ptr(), - M, N, K, - K / 2, 1, // stride_a_m, stride_a_k - K / 16, 1, // stride_sfa_m, stride_sfa_k - K / 2, 1, // stride_b_n, stride_b_k - K / 16, 1, // stride_sfb_n, stride_sfb_k - N, 1 // stride_c_m, stride_c_n - ); - - return C; -} - -TORCH_LIBRARY(nvfp4_blockscaled, m) { - m.def("gemm_forward", &nvfp4_blockscaled_gemm_forward); -} -""" - -# --------------------------------------------------------------------------- -# Kernel compilation and caching -# --------------------------------------------------------------------------- - -_compiled_ext = None -_ext_lock = None - - -def _get_compiled_extension(): - """Compile and cache the CUDA extension.""" - global _compiled_ext - if _compiled_ext is not None: - return _compiled_ext - - import torch.utils.cpp_extension as cpp_ext - - with tempfile.TemporaryDirectory() as tmpdir: - cu_path = os.path.join(tmpdir, "nvfp4_blockscaled_gemm.cu") - with open(cu_path, "w") as f: - f.write(NVFP4_BLOCKSCALED_GEMM_CUDA) - - ext = cpp_ext.load( - name="nvfp4_blockscaled", - sources=[cu_path], - extra_cuda_cflags=[ - "-gencode=arch=compute_100a,code=sm_100a", - "--expt-relaxed-constexpr", - "-DNVFP4_BLOCKSCALED_ENABLED=1", - ], - extra_cflags=["-O2"], - verbose=MEGA_MOE_DEBUG, - ) - - _compiled_ext = ext - return ext - - -# --------------------------------------------------------------------------- -# Native NVFP4 GEMM API -# --------------------------------------------------------------------------- - -def nvfp4_blockscaled_gemm( - A_packed: torch.Tensor, # (M, K//2) int8 — E2M1 packed, K-major - A_scales: torch.Tensor, # (M, K//16) float8_e4m3fn — UE4M3 block16 scales - B_packed: torch.Tensor, # (N, K//2) int8 — E2M1 packed, K-major - B_scales: torch.Tensor, # (N, K//16) float8_e4m3fn — UE4M3 block16 scales -) -> torch.Tensor: - """Native NVFP4 block-scaled GEMM: C = A @ B^T. - - A is (M, K//2) int8 E2M1 packed with (M, K//16) UE4M3 scales. - B is (N, K//2) int8 E2M1 packed with (N, K//16) UE4M3 scales. - C is (M, N) float32. - - Uses the native mxf8f6f4.block_scale tensor core instruction on Blackwell. - Falls back to dequantize+BF16-GEMM on non-Blackwell hardware. - """ - M = A_packed.shape[0] - K_half = A_packed.shape[1] - K = K_half * 2 - N = B_packed.shape[0] - - assert A_packed.dtype == torch.int8, f"A must be int8, got {A_packed.dtype}" - assert B_packed.dtype == torch.int8, f"B must be int8, got {B_packed.dtype}" - assert A_packed.is_cuda and B_packed.is_cuda, "Tensors must be on CUDA" - - # Try native path - try: - ext = _get_compiled_extension() - # Ensure scales are uint8 view of float8_e4m3fn - if A_scales.dtype == torch.float8_e4m3fn: - A_sf_u8 = A_scales.view(torch.uint8) - else: - A_sf_u8 = A_scales - if B_scales.dtype == torch.float8_e4m3fn: - B_sf_u8 = B_scales.view(torch.uint8) - else: - B_sf_u8 = B_scales - - return ext.gemm_forward(A_packed, A_sf_u8, B_packed, B_sf_u8, M, N, K) - except Exception as e: - if MEGA_MOE_DEBUG: - print(f"[nvfp4_gemm] Native kernel failed, using dequant fallback: {e}") - # Fallback: dequantize and use torch.matmul - from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16 - A_bf16 = unpack_e2m1_to_bf16(A_packed, A_scales) # (M, K) - B_bf16 = unpack_e2m1_to_bf16(B_packed, B_scales) # (N, K) - return torch.matmul(A_bf16.to(torch.float32), B_bf16.to(torch.float32).t()) - - -# --------------------------------------------------------------------------- -# MoE Grouped GEMM with native NVFP4 block-scaled MMA -# --------------------------------------------------------------------------- - -def grouped_gemm_nvfp4_native( - x_packed: torch.Tensor, # (num_tokens, K//2) int8 — E2M1 packed - x_scales: torch.Tensor, # (num_tokens, K//16) UE4M3 scales - weights: torch.Tensor, # (E, N, K//2) int8 — per-expert E2M1 weights - weight_scales: torch.Tensor, # (E, N, K//16) UE4M3 per-expert scales - topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32 - topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32 -) -> torch.Tensor: - """Segmented grouped expert GEMM with native NVFP4 block-scaled MMA. - - For each expert, runs the native NVFP4 GEMM on tokens routed to it. - Results are scattered back with routing weights. - - Args: - x_packed: Packed E2M1 activations (num_tokens, K//2) - x_scales: UE4M3 block16 scales (num_tokens, K//16) - weights: Per-expert E2M1 weights (E, N, K//2) - weight_scales: Per-expert UE4M3 scales (E, N, K//16) - topk_ids: Expert assignments (num_tokens, NUM_TOPK) - topk_weights: Routing weights (num_tokens, NUM_TOPK) - - Returns: - (num_tokens, N) bfloat16 output - """ - num_tokens = x_packed.shape[0] - K_half = x_packed.shape[1] - K = K_half * 2 - E = weights.shape[0] - N = weights.shape[1] - top_k = topk_ids.shape[1] - device = x_packed.device - - output = torch.zeros(num_tokens, N, dtype=torch.float32, device=device) - - # Process per expert - for e in range(E): - mask = (topk_ids == e) # (num_tokens, top_k) - if not mask.any(): - continue - - for k_idx in range(top_k): - token_mask = mask[:, k_idx] - if not token_mask.any(): - continue - token_indices = token_mask.nonzero(as_tuple=True)[0] - - # Gather activations for this expert - x_sub_packed = x_packed[token_indices] # (n, K//2) - x_sub_scales = x_scales[token_indices] # (n, K//16) - w_packed = weights[e] # (N, K//2) - w_scales = weight_scales[e] # (N, K//16) - - # Native NVFP4 GEMM: (n, K) @ (N, K)^T → (n, N) - result = nvfp4_blockscaled_gemm( - x_sub_packed, x_sub_scales, - w_packed, w_scales, - ) # (n, N) float32 - - # Weighted scatter-add - weights_f32 = topk_weights[token_indices, k_idx].unsqueeze(-1) - output[token_indices] += result * weights_f32 - - return output.to(torch.bfloat16) - - -# --------------------------------------------------------------------------- -# Convenience wrappers for uint32 packed scales -# --------------------------------------------------------------------------- - -def grouped_gemm_nvfp4_packed_sf( - x_packed: torch.Tensor, # (num_tokens, K//2) int8 - x_sf_packed: torch.Tensor, # (num_tokens, sf_groups) uint32 packed UE4M3 - weights: torch.Tensor, # (E, N, K//2) int8 - weight_sf: torch.Tensor, # (E, N, sf_groups) uint32 packed UE4M3 - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, -) -> torch.Tensor: - """Grouped GEMM with uint32 packed UE4M3 scales.""" - from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_ue4m3_u32 - - x_sf_fp8 = unpack_ue4m3_u32(x_sf_packed) if x_sf_packed.dtype == torch.uint32 else x_sf_packed - w_sf_fp8 = unpack_ue4m3_u32(weight_sf) if weight_sf.dtype == torch.uint32 else weight_sf - - return grouped_gemm_nvfp4_native( - x_packed, x_sf_fp8, - weights, w_sf_fp8, - topk_ids, topk_weights, - ) - - -# --------------------------------------------------------------------------- -# TileLang-based NVFP4 GEMM (using T.gemm with float8_e4m3) -# --------------------------------------------------------------------------- - -_tilelang_kernel_cache = {} - - -def _make_tilelang_nvfp4_gemm(M, N, K_packed, block_M=128, block_N=128, block_K=64): - """Build a TileLang GEMM kernel using float8_e4m3 (f8f6f4) tensor cores. - - This uses TileLang's T.gemm() with float8_e4m3 dtype, which lowers to - tcgen05.mma.kind::f8f4 on Blackwell. Note: this path does NOT apply - UE4M3 block scaling natively — it does E2M1 × E2M1 without scales. - For proper NVFP4 with block scaling, use nvfp4_blockscaled_gemm(). - - The TileLang path is kept for experimentation and benchmarking. - """ - key = (M, N, K_packed, block_M, block_N, block_K) - if key in _tilelang_kernel_cache: - return _tilelang_kernel_cache[key] - - import tilelang - import tilelang.language as T - - K = K_packed * 2 # Unpacked element count - - @tilelang.jit(out_idx=[2]) - def fp4_gemm( - A: T.Tensor((M, K_packed), "float8_e4m3"), - B: T.Tensor((K_packed, N), "float8_e4m3"), - C: T.Tensor((M, N), T.float32), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), "float8_e4m3") - B_shared = T.alloc_shared((block_K, block_N), "float8_e4m3") - C_local = T.alloc_fragment((block_M, block_N), T.float32) - - T.clear(C_local) - - for k in T.Pipelined(T.ceildiv(K_packed, block_K), num_stages=2): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, num_elems_per_byte=2) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - _tilelang_kernel_cache[key] = fp4_gemm - return fp4_gemm diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 02eefe70..1d61f9bc 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -1,10 +1,10 @@ """ -NVFP4 Weight Transformation for TileLang mega_moe kernel. +NVFP4 Weight Transformation for CUTLASS mega_moe kernel. Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale) -into the TMA-aligned format expected by the block-scaled GEMM kernel: -- Packed FP4 weights (uint8, K-major) -- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout) +into the format expected by the CUTLASS block-scaled GEMM kernel: +- Packed FP4 weights (int8, K-major) +- UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles layout) This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe. @@ -70,13 +70,13 @@ def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor: return interleaved.contiguous() -def transform_nvfp4_weights_for_tilelang( +def transform_nvfp4_weights_for_mega_moe( l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1 l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2 ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - """Transform NVFP4 weights for the TileLang block-scaled GEMM. + """Transform NVFP4 weights for the CUTLASS block-scaled GEMM. Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights. @@ -113,7 +113,3 @@ def transform_nvfp4_weights_for_tilelang( l2_weight_out = l2_weight.contiguous() return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out) - - -# Alias for drop-in replacement -transform_nvfp4_weights_for_mega_moe = transform_nvfp4_weights_for_tilelang diff --git a/src/quantize.mojo b/src/quantize.mojo deleted file mode 100644 index 71a7126f..00000000 --- a/src/quantize.mojo +++ /dev/null @@ -1,95 +0,0 @@ -""" -NVFP4 quantization utilities — E2M1 packing and UE4M3 scale handling. -Core math layer for the NVFP4 mega_moe kernel rewrite. -""" - -# E2M1 magnitude lookup table (positive values only) -# Index 0-7 maps to: 0, 0.5, 1, 1.5, 2, 3, 4, 6 -def e2m1_magnitude(index: Int) -> Float64: - if index == 0: return 0.0 - if index == 1: return 0.5 - if index == 2: return 1.0 - if index == 3: return 1.5 - if index == 4: return 2.0 - if index == 5: return 3.0 - if index == 6: return 4.0 - if index == 7: return 6.0 - return 0.0 - - -def quantize_e2m1(value: Float64) -> UInt8: - """Quantize a float64 value to E2M1 (4-bit), returning the 4-bit nibble with sign.""" - var sign = 0 - var abs_val = value - if value < 0.0: - sign = 1 - abs_val = -value - - # Find best E2M1 match - var best_idx = 0 - var best_err = abs_val # error for idx=0 - - for i in range(1, 8): - mag = e2m1_magnitude(i) - err = abs(abs_val - mag) - if err < best_err: - best_err = err - best_idx = i - - return (sign << 3) | best_idx - - -def unpack_e2m1(packed: UInt8, idx: Int) -> Float64: - """Unpack one E2M1 value from a packed byte. - idx=0 -> low nibble, idx=1 -> high nibble. - """ - nibble: UInt8 - if idx == 0: - nibble = packed & 0x0F - else: - nibble = (packed >> 4) & 0x0F # keep sign bit - - sign = (nibble >> 3) & 1 - mag_idx = nibble & 0x07 - magnitude = e2m1_magnitude(Int(mag_idx)) - - if sign: - return -magnitude - return magnitude - - -def dequantize_nvfp4_weight( - packed_weight: UInt8, - block_scale: Float64, - group_offset: Int, -) -> Float64: - """Dequantize a single NVFP4 weight element. - weight = E2M1_magnitude * block_scale - (global_scale is already folded into block_scale) - """ - e2m1_value = unpack_e2m1(packed_weight, group_offset) - return e2m1_value * block_scale - - -def main() raises: - # Test E2M1 quantization round-trip - print("E2M1 quantization test:") - for val in [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]: - packed = quantize_e2m1(val) - unpacked = unpack_e2m1(packed, 0) - print(" ", val, " -> E2M1 -> ", unpacked) - - # Test packed byte (two E2M1 values) - print("\nPacked byte test:") - lo = 1.5 - hi = -3.0 - packed = (quantize_e2m1(hi) << 4) | quantize_e2m1(lo) - print(" lo=", lo, " hi=", hi, " packed=", packed) - print(" unpack lo=", unpack_e2m1(packed, 0), " unpack hi=", unpack_e2m1(packed, 1)) - - # Test NVFP4 dequantization - print("\nNVFP4 dequantization test:") - packed_w = UInt8(0x36) # low=6.0, high=3.0 - scale = 0.5 - print(" packed=0x36, scale=0.5, lo=", dequantize_nvfp4_weight(packed_w, scale, 0)) - print(" packed=0x36, scale=0.5, hi=", dequantize_nvfp4_weight(packed_w, scale, 1)) diff --git a/src/staging.py b/src/staging.py deleted file mode 100644 index 42a9ec4d..00000000 --- a/src/staging.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -NVFP4 Activation Quantization — BF16 → packed FP4 + UE4M3 block16 scales - -Port of patches/staging_kernel.py to TileLang. - -This quantizes BF16 hidden states to: -- Packed FP4 (E2M1, 2 values per uint8 byte) -- UE4M3 block16 scales (4 values per uint32 word, group_size=16) -""" - -import torch -import tilelang -import tilelang.language as T - - -@tilelang.jit -def fp4_quant_kernel( - X, # (M, K) bfloat16 input - block_M=128, - block_K=128, - group_size=16, # NVFP4 UE4M3 group_size -): - """Quantize BF16 → packed FP4 (E2M1) + UE4M3 block16 scales. - - Output: - - X_FP4: (M, K//2) uint8 packed E2M1 - - X_SF: (M, K//group_size//4) uint32 packed UE4M3 scales - """ - M, K = T.const("M, K") - X: T.Tensor[[M, K], "bfloat16"] - - X_FP4 = T.empty((M, K // 2), "uint8") - X_SF = T.empty((M, K // (group_size * 4)), "uint32") - - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=128) as (bx, by): - # Each block quantizes a (block_M, block_K) tile - tx = T.get_thread_binding() - - row = bx * block_M + tx // (block_K // 2) - col = tx % (block_K // 2) - - if row < M and col < K // 2: - # Load 2 BF16 values - val_lo = X[row, col * 2] - val_hi = X[row, col * 2 + 1] - - # Quantize to E2M1 (4-bit each, pack 2 per byte) - # ... (E2M1 quantization logic) - packed = T.uint8(0) # placeholder - X_FP4[row, col] = packed - - # Compute block scale for group of 16 elements - # ... (UE4M3 scale computation) - - return X_FP4, X_SF - - -def fp4_quant_reference(x_bf16, group_size=16): - """Reference FP4 quantization using PyTorch + DeepGEMM. - - Used for correctness verification of the TileLang kernel. - """ - from deep_gemm import per_token_cast_to_fp4 - x_fp4, x_sf = per_token_cast_to_fp4(x_bf16) - return x_fp4, x_sf diff --git a/src/weight_transform.py b/src/weight_transform.py deleted file mode 100644 index 528dbdb5..00000000 --- a/src/weight_transform.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -NVFP4 Weight Transformation for TileLang mega_moe kernel. - -Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale) -into the TMA-aligned format expected by the block-scaled GEMM kernel: -- Packed FP4 weights (uint8, K-major) -- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout) - -This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe. -""" - -import torch -import math - - -def fold_global_scale( - weight_scale: torch.Tensor, # (N, K//16) float8_e4m3fn - weight_scale_2: torch.Tensor, # (num_logical,) or scalar float32 - logical_widths: list[int] = None, # per-logical-weight row counts -) -> torch.Tensor: - """Fold global scale into block scales: UE4M3 * FP32 → UE4M3 → float32. - - Returns: (N, K//16) float32 folded block scales. - """ - sf_f32 = weight_scale.to(torch.float32) - - if weight_scale_2.numel() == 1: - sf_f32 = sf_f32 * weight_scale_2.to(torch.float32) - elif weight_scale_2.numel() > 1 and logical_widths is not None: - # Per-logical-weight global scale (e.g., gate_up_proj has 2) - expanded = [] - for i, w in enumerate(logical_widths): - if i < len(weight_scale_2): - expanded.append(weight_scale_2[i].flatten()[0].expand(w)) - global_scale = torch.cat(expanded).to(torch.float32).unsqueeze(1) - sf_f32 = sf_f32 * global_scale - else: - sf_f32 = sf_f32 * weight_scale_2.max().to(torch.float32) - - return sf_f32 - - -def pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor: - """Pack 4 UE4M3 (float8_e4m3fn) values into one uint32. - - Input: (..., K//16) float8_e4m3fn - Output: (..., K//64) uint32 (4 values packed per word) - """ - # View as raw bytes - sf_u8 = sf.view(torch.uint8) # (..., K//16) uint8 - shape = sf_u8.shape - assert shape[-1] % 4 == 0, f"Last dim {shape[-1]} not divisible by 4" - - # Pack 4 consecutive uint8 values into one uint32 - packed = (sf_u8[..., 0::4].to(torch.int32) | - (sf_u8[..., 1::4].to(torch.int32) << 8) | - (sf_u8[..., 2::4].to(torch.int32) << 16) | - (sf_u8[..., 3::4].to(torch.int32) << 24)) - - return packed.contiguous() - - -def interleave_l1_weights( - weight: torch.Tensor, # (E, 2*INTER, K//2) int8, K-major -) -> torch.Tensor: - """Interleave L1 (gate_up) weights for 2CTA UMMA. - - The gate and up projections are interleaved in groups of 8 rows - to match the UMMA 2CTA schedule. - """ - E, N, K_half = weight.shape - assert N % 16 == 0, f"N={N} not divisible by 16" - - # Reshape to (E, N//16, 16, K_half) → interleave pairs of 8 - w = weight.view(E, N // 16, 16, K_half) - # Interleave: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...] - # Actually: pairs of 8 rows from gate and up - w_gate = w[:, :, :8, :] - w_up = w[:, :, 8:16, :] - interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half) - - return interleaved.contiguous() - - -def transform_nvfp4_weights_for_tilelang( - weight: torch.Tensor, # (E, N, K//2) int8, K-major packed FP4 - weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn UE4M3 - weight_scale_2: torch.Tensor, # (E, 1) or (E, 2) float32 global scale - N: int, # output dimension (2*INTER for L1, HIDDEN for L2) - K: int, # input dimension (HIDDEN for L1, INTER for L2) - gran_k: int = 16, # NVFP4 group_size - is_l1: bool = False, # True for gate_up_proj (needs interleaving) - logical_widths: list[int] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """Transform NVFP4 weights for the TileLang block-scaled GEMM. - - Returns: - transformed_weight: (E, N, K//2) int8, K-major, contiguously interleaved - transformed_sf: (E, N, K//64) uint32, packed UE4M3 with folded global scale - """ - device = weight.device - - # Step 1: Fold global scale into block scales - sf_folded = fold_global_scale(weight_scale, weight_scale_2, logical_widths) - - # Step 2: Clamp and convert back to UE4M3 - sf_clamped = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - - # Step 3: Pack UE4M3 into uint32 - sf_packed = pack_ue4m3_to_uint32(sf_clamped) # (E, N, K//64) - - # Step 4: Interleave L1 weights if needed - if is_l1: - transformed_weight = interleave_l1_weights(weight) - else: - transformed_weight = weight.contiguous() - - # Step 5: Ensure K-major layout (already K-major from checkpoint) - # The weight should be (E, N, K//2) with stride (N*K//2, K//2, 1) for K-major - - return transformed_weight, sf_packed - - -def main(): - """Test the weight transformation with dummy data.""" - E = 4 # small test - N = 64 - K = 128 - - device = "cpu" # No GPU on sandbox - weight = torch.randint(-128, 127, (E, N, K // 2), dtype=torch.int8, device=device) - weight_scale = torch.randn(E, N, K // 16, device=device).abs().to(torch.float8_e4m3fn) - weight_scale_2 = torch.ones(E, 2, dtype=torch.float32, device=device) - - transformed_w, transformed_sf = transform_nvfp4_weights_for_tilelang( - weight, weight_scale, weight_scale_2, N, K, is_l1=True, - logical_widths=[32, 32], - ) - - print(f"Weight: {transformed_w.shape} {transformed_w.dtype}") - print(f"SF: {transformed_sf.shape} {transformed_sf.dtype}") - print(f"Expected SF: (E={E}, N={N}, K//64={K//64})") - print("Weight transformation test passed!") - - -if __name__ == "__main__": - main()