From d3f35c9465c7c75f47091f38d69a72f06e1bddcf Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 12:44:47 +0000 Subject: [PATCH] cleanup: remove abandoned TileLang and Mojo files - Deleted: layout.mojo, mega_moe.mojo, quantize.mojo (Mojo attempt) - Deleted: nvfp4_blockscaled_gemm.py, staging.py, nvfp4_mega_moe.py (TileLang top-level) - Deleted: tilelang_nvfp4_gemm.py, tilelang_kernels.py, nvfp4_dequant.py (TileLang package) - Deleted: src/weight_transform.py (duplicate of package version) - Fixed nvfp4_mega_moe.py: inlined unpack_ue4m3_u32, removed TileLang fallback imports - Fixed weight_transform.py: renamed function, removed TileLang alias, updated docs - Fixed __init__.py: removed TileLang alias, updated docstring - CUTLASS is the only kernel path now --- src/layout.mojo | 32 - src/mega_moe.mojo | 24 - src/nvfp4_blockscaled_gemm.py | 256 -------- src/nvfp4_mega_moe.py | 163 ----- src/nvfp4_megamoe_kernel/__init__.py | 4 +- .../__pycache__/__init__.cpython-312.pyc | Bin 754 -> 0 bytes .../__pycache__/nvfp4_dequant.cpython-312.pyc | Bin 3883 -> 0 bytes .../nvfp4_mega_moe.cpython-312.pyc | Bin 7430 -> 0 bytes .../__pycache__/symm_buffer.cpython-312.pyc | Bin 3431 -> 0 bytes .../tilelang_kernels.cpython-312.pyc | Bin 6192 -> 0 bytes .../weight_transform.cpython-312.pyc | Bin 6637 -> 0 bytes src/nvfp4_megamoe_kernel/nvfp4_dequant.py | 71 --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 64 +- src/nvfp4_megamoe_kernel/tilelang_kernels.py | 136 ---- .../tilelang_nvfp4_gemm.py | 599 ------------------ src/nvfp4_megamoe_kernel/weight_transform.py | 16 +- src/quantize.mojo | 95 --- src/staging.py | 65 -- src/weight_transform.py | 147 ----- 19 files changed, 36 insertions(+), 1636 deletions(-) delete mode 100644 src/layout.mojo delete mode 100644 src/mega_moe.mojo delete mode 100644 src/nvfp4_blockscaled_gemm.py delete mode 100644 src/nvfp4_mega_moe.py delete mode 100644 src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc delete mode 100644 src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc delete mode 100644 src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc delete mode 100644 src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc delete mode 100644 src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc delete mode 100644 src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc delete mode 100644 src/nvfp4_megamoe_kernel/nvfp4_dequant.py delete mode 100644 src/nvfp4_megamoe_kernel/tilelang_kernels.py delete mode 100644 src/nvfp4_megamoe_kernel/tilelang_nvfp4_gemm.py delete mode 100644 src/quantize.mojo delete mode 100644 src/staging.py delete mode 100644 src/weight_transform.py diff --git a/src/layout.mojo b/src/layout.mojo deleted file mode 100644 index af284d20..00000000 --- a/src/layout.mojo +++ /dev/null @@ -1,32 +0,0 @@ -""" -NVFP4 weight transformation and SF layout utilities. - -Port of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe -""" - -from math import ceil_div - -fn fold_global_scale_into_block_scales( - weight_scale: Tensor[float8_e4m3fn], # (N, K//16) UE4M3 block scales - weight_scale_2: Tensor[float32], # (num_logical,) or scalar global scale - logical_widths: List[int], # per-logical-weight row counts -) -> Tensor[float32]: - """Fold global scale into block scales: UE4M3 * FP32 -> FP32""" - # Convert UE4M3 to float32, multiply by global scale - # For MergedColumnParallelLinear, expand per-logical global scale - ... - -fn pack_ue4m3_to_int32(sf: Tensor[float8_e4m3fn]) -> Tensor[int32]: - """Pack 4 UE4M3 values (4 bytes) into one int32 for DeepGEMM TMA""" - # View as uint8, pack 4 consecutive bytes into int32 - ... - -fn transform_sf_into_required_layout( - sf_mn: Tensor[int32], # MN-major packed SF - N: int, K: int, - recipe: Tuple[int, int], # (gran_mn, gran_k) - num_groups: int, -) -> Tensor[int32]: - """Transform SF into TMA-aligned UTCCP layout for DeepGEMM""" - # Call into DeepGEMM's C++ layout transform - ... diff --git a/src/mega_moe.mojo b/src/mega_moe.mojo deleted file mode 100644 index d9056f5d..00000000 --- a/src/mega_moe.mojo +++ /dev/null @@ -1,24 +0,0 @@ -""" -NVFP4 Mega MoE Kernel — Mojo Rewrite - -This is a from-scratch rewrite of the DeepGEMM fp8_nvfp4_mega_moe kernel. -The CUDA C++ version crashes on B200 with CUDA_ERROR_LAUNCH_FAILED in the -SM100_MMA_MXF4NVF4 instruction. This Mojo rewrite aims to: -1. Produce a correct, working NVFP4 mega_moe kernel -2. Be more maintainable than 1200+ lines of CUDA template metaprogramming -3. Leverage Mojo's GPU programming model for cleaner TMA/UMMA integration - -NVFP4 format: -- Weights: E2M1 packed (2 x 4-bit values per byte), int8 container -- Block scales: UE4M3 (float8_e4m3fn), group_size=16 -- Global scale: float32 scalar -- Activation: FP8 e4m3fn with UE8M0 per-token scales - -Key differences from MXFP4 (group_size=32, UE8M0): -- 2 SF K-columns per BLOCK_K (128/16/4=2) instead of 1 -- mxf4nvf4 UMMA instruction with scale_vec::4X -- Different weight transformation (gran_k=16 vs gran_k=32) -""" - -# TODO: Implement as Mojo GPU kernel -# Waiting for Mojo GPU programming docs / SDK setup diff --git a/src/nvfp4_blockscaled_gemm.py b/src/nvfp4_blockscaled_gemm.py deleted file mode 100644 index 135c65bc..00000000 --- a/src/nvfp4_blockscaled_gemm.py +++ /dev/null @@ -1,256 +0,0 @@ -""" -NVFP4 Block-Scaled GEMM — TileLang -2CTA persistent kernel for Blackwell SM100 - -Adapted from tilelang/examples/blockscaled_gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py - -Key NVFP4 differences from MXFP8: -- sf_granularity_k=16 (UE4M3 block16) instead of 128 (UE8M0) -- 2 SF uint32 words per BLOCK_K instead of 1 (sf_load_period=1, not 4) -- sf_a_id/sf_b_id cycle through 0,1,2,3 within each SF word -- Must call tcgen05_gemm_blockscaled twice per K-block with different SF words -- in_dtype="float4_e2m1fnx2" (packed FP4) instead of "float8_e4m3fn" (FP8) -""" - -import torch -import tilelang -import tilelang.language as T -from tilelang.carver.arch import driver -from tilelang.profiler import do_bench - -# NVFP4 constants -NVFP4_SF_GRAN_K = 16 # UE4M3 group_size -NVFP4_SF_PACK = 4 # 4 UE4M3 values per uint32 -NVFP4_SF_K_PER_WORD = NVFP4_SF_GRAN_K * NVFP4_SF_PACK # 64 K-elements per uint32 word -# For BLOCK_K=128: 128/64 = 2 SF words per K-block -# Each word covers 4 UMMA sub-atoms (sf_id 0-3) - - -@tilelang.jit -def nvfp4_blockscaled_gemm_2cta_persistent( - A, # (M, K) packed FP4 - B, # (N, K) packed FP4 (K-major for NN) - SFA, # (sf_k_groups * M) uint32 packed UE4M3 - SFB, # (sf_k_groups * N) uint32 packed UE4M3 - block_M=128, - block_N=256, - block_K=128, - in_dtype="float4_e2m1fnx2", - out_dtype="bfloat16", - accum_dtype="float32", - num_stages=2, - sf_granularity_k=NVFP4_SF_GRAN_K, - use_tma_store=True, - store_block_N=64, -): - M, N, K = T.const("M, N, K") - - assert block_M == 128 - assert block_N == 256 - assert block_K == 128 - - half_N = block_N // 2 - k_iters = T.ceildiv(K, block_K) - - # NVFP4 SF layout: - # sf_granularity_k=16, pack_factor=4 → 64 K-elements per uint32 - # BLOCK_K=128 → 2 SF words per K-block - # sf_load_period=1 (load every K-block, unlike MXFP8 which loads every 4) - sf_words_per_k_block = block_K // (sf_granularity_k * NVFP4_SF_PACK) # 2 - sf_k_groups = T.ceildiv(K, sf_granularity_k * NVFP4_SF_PACK) # total SF groups across K - - A: T.Tensor[[M, K], in_dtype] - B: T.Tensor[[N, K], in_dtype] - SFA: T.Tensor[[sf_k_groups * M], T.uint32] - SFB: T.Tensor[[sf_k_groups * N], T.uint32] - C = T.empty((M, N), out_dtype) - - sm_num = driver.get_num_sms() - num_clusters = sm_num // 2 - m_blocks = T.ceildiv(M, block_M) - m_clusters = m_blocks // 2 - n_blocks = T.ceildiv(N, block_N) - waves = T.ceildiv(m_blocks * n_blocks, sm_num) - group_size = 16 - assert n_blocks % (2 * group_size) == 0 - - with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id): - cta_id = T.block_rank_in_cluster() - T.assume(cta_id < 2) - - # Shared memory — FP4 packed (K is halved because 2 values per byte) - A_shared = T.alloc_shared((num_stages, block_M, block_K // 2), "uint8") - B_shared = T.alloc_shared((num_stages, block_K // 2, half_N), "uint8") - - # Shared memory for SF — 2 uint32 words per K-block for NVFP4 - SFA_shared = T.alloc_shared((num_stages, block_M, sf_words_per_k_block), T.uint32) - SFB_shared = T.alloc_shared((num_stages, block_N, sf_words_per_k_block), T.uint32) - - # Tensor memory - C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) - SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], T.uint32) - SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], T.uint32) - - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) - C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) - - loaded = T.alloc_barrier([32] * num_stages) - with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages) - consumed = T.alloc_cluster_barrier([1] * num_stages) - tmem_full = T.alloc_cluster_barrier([1]) - tmem_empty = T.alloc_cluster_barrier([128 * 2]) - - tx = T.get_thread_binding() - warp_idx = tx // 32 - - if warp_idx == 0: - # Warp 0: TMA load - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - for k in T.serial(k_iters): - phase = w * k_iters + k - stage = phase % num_stages - parity = (phase // num_stages) & 1 - - T.mbarrier_wait_parity(consumed[stage], parity ^ 1) - - # TMA load A (packed FP4) - T.tma_copy( - A[bx * block_M, k * block_K], - A_shared[stage, :, :], - barrier=loaded[stage], - ) - # TMA load B (packed FP4) - T.tma_copy( - B[k * block_K, by * block_N + cta_id * half_N], - B_shared[stage, :, :], - barrier=loaded[stage], - ) - # TMA load SFA — 2 words per K-block for NVFP4 - # Word 0 covers K[k*128 : k*128+64] - # Word 1 covers K[k*128+64 : k*128+128] - for sf_word in T.unroll(sf_words_per_k_block): - sf_group = k * sf_words_per_k_block + sf_word - T.tma_copy( - SFA[sf_group * M + bx * block_M : - sf_group * M + (bx + 1) * block_M], - SFA_shared[stage, :, sf_word], - barrier=loaded[stage], - ) - # TMA load SFB — 2 words per K-block - for sf_word in T.unroll(sf_words_per_k_block): - sf_group = k * sf_words_per_k_block + sf_word - T.tma_copy( - SFB[sf_group * N + by * block_N : - sf_group * N + (by + 1) * block_N], - SFB_shared[stage, :, sf_word], - barrier=loaded[stage], - ) - T.mbarrier_arrive(loaded[stage]) - - elif warp_idx == 1 and cta_id == 0: - # Warp 1: MMA issue + UTCCP - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - for k in T.serial(k_iters): - phase = w * k_iters + k - stage = phase % num_stages - parity = (phase // num_stages) & 1 - - T.mbarrier_wait_parity(with_sf_full[stage], parity) - - # NVFP4: 2 SF words per K-block → 2 sub-GEMM calls - # Each sub-GEMM covers 64 K-elements (BLOCK_K/2) - for sf_word in T.unroll(sf_words_per_k_block): - # UTCCP: copy SF word from shared to tensor memory - T.tcgen05_cp_warpx4( - SFA_shared[stage, :, sf_word], - SFA_tmem, - use_2cta=True, - ) - T.tcgen05_cp_warpx4( - SFB_shared[stage, :, sf_word], - SFB_tmem, - use_2cta=True, - ) - - # Block-scaled GEMM - # For NVFP4: sf_a_id selects which of 4 packed UE4M3 - # values in the uint32 word to use. - # With sf_granularity_k=16 and UMMA_K=64: - # 4 SF values per UMMA atom, sf_id=0 covers all 4 - # (the hardware auto-cycles through 0,1,2,3 internally) - T.tcgen05_gemm_blockscaled( - A_shared[stage, :, :], - B_shared[stage, :, :], - C_tmem, - SFA_tmem, - SFB_tmem, - transpose_B=True, - mbar=consumed[stage], - clear_accum=(k == 0 and sf_word == 0), - sf_a_id=0, - sf_b_id=0, - use_2cta=True, - ) - - T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) - - elif warp_idx == 2: - # Warp 2: SF transpose - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - for k in T.serial(k_iters): - phase = w * k_iters + k - stage = phase % num_stages - parity = (phase // num_stages) & 1 - - T.mbarrier_wait_parity(loaded[stage], parity) - # Transpose all SF words for this K-block - for sf_word in T.unroll(sf_words_per_k_block): - T.tcgen05_sf_warp_transpose(SFA_shared[stage, :, sf_word]) - T.tcgen05_sf_warp_transpose(SFB_shared[stage, :, sf_word]) - T.fence_proxy_async() - T.mbarrier_arrive(with_sf_full[stage], 0) - - # Epilogue: write C from tmem to global memory - for w in T.serial(waves): - cluster_id = block_id // 2 - tile_id = num_clusters * w + cluster_id - bx_cluster = (tile_id // group_size) % m_clusters - bx = bx_cluster * 2 + cta_id - by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size - - if bx * block_M < M and by * block_N < N: - T.mbarrier_wait_parity(tmem_full, w & 1) - T.copy(C_tmem, C_local) - T.copy(C_local, C_local_cast) - T.copy(C_local_cast, C_shared) - - if use_tma_store: - T.copy(C_shared, C[bx * block_M, by * block_N]) - else: - T.copy(C_shared, C[bx * block_M, by * block_N], disable_tma=True) - - T.mbarrier_arrive(tmem_empty, 0) - - return C diff --git a/src/nvfp4_mega_moe.py b/src/nvfp4_mega_moe.py deleted file mode 100644 index ee0e26f1..00000000 --- a/src/nvfp4_mega_moe.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -NVFP4 Mega MoE Kernel — Full MoE with expert parallelism. - -This is the main kernel that replaces fp8_nvfp4_mega_moe from DeepGEMM. - -Architecture: -- L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with UE4M3 scales) -- SiLU+Mul activation -- L2 GEMM: down_proj (FP4 x FP4 → BF16 with UE4M3 scales) -- NVLink cross-rank sync via symm buffer -- Expert parallel: each rank handles NUM_EXPERTS/8 experts - -The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN. -""" - -import torch -import tilelang -import tilelang.language as T -from tilelang.carver.arch import driver - -# DeepSeek-V4-Pro dimensions -HIDDEN = 7168 -INTERMEDIATE = 3072 -NUM_EXPERTS = 256 -NUM_RANKS = 8 -NUM_TOPK = 6 - -# Block sizes for the block-scaled GEMM -BLOCK_M = 192 # tokens per tile -BLOCK_K = 128 -BLOCK_N = 128 - -# NVFP4 scale parameters -SF_GRANULARITY_K = 16 # UE4M3 group_size -SF_PACK_FACTOR = 4 # 4 UE4M3 values per uint32 -SF_WORDS_PER_K_BLOCK = BLOCK_K // (SF_GRANULARITY_K * SF_PACK_FACTOR) # 2 - - -@tilelang.jit -def nvfp4_mega_moe_l1( - # Activation (packed FP4, from staging kernel) - X, # (num_tokens, K//2) int8 packed E2M1 - X_SF, # (num_tokens, sf_k_groups) uint32 packed UE4M3 - # L1 weights (pre-transformed for mega_moe) - L1_W, # (num_experts_per_rank, 2*INTERMEDIATE, K//2) int8 K-major - L1_SF, # (num_experts_per_rank, 2*INTERMEDIATE, sf_k_groups) uint32 TMA-aligned - # Routing - TopkIdx, # (num_tokens, NUM_TOPK) int32 - TopkW, # (num_tokens, NUM_TOPK) float32 - # Output - Y, # (num_tokens, 2*INTERMEDIATE) bfloat16 - num_experts_per_rank, -): - """ - L1 GEMM for mega_moe: gate_up_proj - X(FP4) @ L1_W(FP4)^T → Y(BF16) with UE4M3 block scaling - - This handles the gate_up_proj for all experts in the rank. - Each token is routed to NUM_TOPK experts, and the GEMM is computed - per expert group using persistent scheduling. - """ - num_tokens = T.const("num_tokens") - K = T.const("K") - INTER = T.const("INTERMEDIATE") - - k_iters = T.ceildiv(K, BLOCK_K) - sf_k_groups = T.ceildiv(K, SF_GRANULARITY_K * SF_PACK_FACTOR) - - with T.Kernel( - T.ceildiv(num_tokens, BLOCK_M), - T.ceildiv(2 * INTER, BLOCK_N), - threads=256, - cluster_dims=2, - ) as (bx, by): - cta_id = T.block_rank_in_cluster() - T.assume(cta_id < 2) - - # ... (persistent scheduling, expert routing, L1 GEMM with block scaling) - # This follows the same pattern as nvfp4_blockscaled_gemm_2cta_persistent - # but adds expert scheduling and topk weight scaling - - pass # Full implementation in progress - - -@tilelang.jit -def nvfp4_mega_moe_l2( - # L1 output (quantized to FP4 for L2 input) - X, # (num_tokens, INTER//2) int8 packed E2M1 - X_SF, # (num_tokens, sf_k_groups) uint32 packed UE4M3 - # L2 weights - L2_W, # (num_experts_per_rank, HIDDEN, INTER//2) int8 K-major - L2_SF, # (num_experts_per_rank, HIDDEN, sf_k_groups) uint32 TMA-aligned - # Routing - TopkIdx, # (num_tokens, NUM_TOPK) int32 - TopkW, # (num_tokens, NUM_TOPK) float32 - # Output - Y, # (num_tokens, HIDDEN) bfloat16 - num_experts_per_rank, -): - """ - L2 GEMM for mega_moe: down_proj - X(FP4) @ L2_W(FP4)^T → Y(BF16) with UE4M3 block scaling - - After SiLU+Mul on the L1 output, the result is quantized to FP4 - and fed into L2. - """ - pass # Symmetric to L1 - - -def nvfp4_mega_moe_full( - hidden_states, # (num_tokens, HIDDEN) bfloat16 - topk_weights, # (num_tokens, NUM_TOPK) float32 - topk_ids, # (num_tokens, NUM_TOPK) int32 - l1_weights, # L1 weights (transformed for mega_moe) - l1_scales, # L1 UE4M3 scales (transformed) - l2_weights, # L2 weights (transformed for mega_moe) - l2_scales, # L2 UE4M3 scales (transformed) - symm_buffer, # NVLink symm buffer for cross-rank sync -): - """ - Full mega_moe forward pass: - 1. Stage: quantize BF16 hidden_states → FP4 + UE4M3 scales - 2. L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with block scaling) - 3. SiLU + Mul (activation) - 4. Quantize L1 output → FP4 + UE4M3 scales - 5. L2 GEMM: down_proj (FP4 x FP4 → BF16 with block scaling) - 6. NVLink sync + reduce across ranks - """ - # Step 1: Stage activation (BF16 → FP4 quantization) - # This is the staging kernel from patches/staging_kernel.py - x_fp4, x_sf = stage_activation(hidden_states) - - # Step 2: L1 GEMM - l1_output = nvfp4_mega_moe_l1( - x_fp4, x_sf, l1_weights, l1_scales, topk_ids, topk_weights) - - # Step 3: SiLU + Mul - gate, up = l1_output.chunk(2, dim=-1) - activated = torch.nn.functional.silu(gate) * up - - # Step 4: Quantize L1 output → FP4 - l1_fp4, l1_sf = stage_activation(activated) - - # Step 5: L2 GEMM - l2_output = nvfp4_mega_moe_l2( - l1_fp4, l1_sf, l2_weights, l2_scales, topk_ids, topk_weights) - - # Step 6: NVLink reduce - output = nvlink_reduce(l2_output, topk_weights, symm_buffer) - - return output - - -def stage_activation(x_bf16): - """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. - - This replaces the Triton staging kernel from patches/staging_kernel.py. - """ - # E2M1 quantization with UE4M3 block scaling - # For now, use PyTorch reference implementation - # TODO: Write as TileLang kernel for full pipeline integration - from deep_gemm import per_token_cast_to_fp4 - return per_token_cast_to_fp4(x_bf16) diff --git a/src/nvfp4_megamoe_kernel/__init__.py b/src/nvfp4_megamoe_kernel/__init__.py index c592bdb5..05cdc12d 100644 --- a/src/nvfp4_megamoe_kernel/__init__.py +++ b/src/nvfp4_megamoe_kernel/__init__.py @@ -1,4 +1,4 @@ -"""NVFP4 Mega MoE Kernel — TileLang implementation for DeepSeek-V4-Pro on Blackwell.""" +"""NVFP4 Mega MoE Kernel — CUTLASS implementation for DeepSeek-V4-Pro on Blackwell.""" from nvfp4_megamoe_kernel.nvfp4_mega_moe import ( nvfp4_mega_moe_full, @@ -7,7 +7,7 @@ from nvfp4_megamoe_kernel.nvfp4_mega_moe import ( stage_activation, ) from nvfp4_megamoe_kernel.weight_transform import ( - transform_nvfp4_weights_for_tilelang as transform_nvfp4_weights_for_mega_moe, + transform_nvfp4_weights_for_mega_moe, ) from nvfp4_megamoe_kernel.symm_buffer import ( SymmBuffer, diff --git a/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc deleted file mode 100644 index 71a6b920f4b71f6c57dd88ccb159f361534f4cac..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 754 zcmaJ_0u>|q+4D3r@6CEY4~HJo@%F+#|3nb_W|HgBc^jKU z{d$iQ6ruztgkwQMBCOC7c4!MHbgV8Clh)ESwxTI6Z#BP0i1q){f@rKIXsVwjOL=`AfTwDI+q1arJ{!brsIw@r*W(3{jTOl<91K>dwt~^ zRdZDudTYatg*%7FUc@OT@PS#Wr)_Vq}tpXxxvg@rNxf?i(Q4kjC<;pZ;FAGJYcIg5D diff --git a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc deleted file mode 100644 index f38c12b5b6d37d57060fc56f46612407e699dd18..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3883 zcmb_eZ%i9U7T;alYkSwm-~?R@0VZj3hk=;b5K3rsJ;v?}-M`jPu`I_d5MTURn#wNiBViEo^^Pt{4iH@kKq zK!cQYq}`b}Gw;ot_ujnU`-j`jQP95qpEY^YNm2hMM7^=K!EPA_3zSIRp+s6_QuG~K zv|u`Et#91xMeEV?QLAYCz4eae2l?2e`F>CyE80Px^&aDMW%?cSc zkrUK}ltEcFkxHlutoS+ZijLL=7$1AcCsTyI=wO;(+iv6_`L5qj^|6tR>AzTFE9u~eQ)?q98KRAJ3* zM_KMsjiW&4X_3wWb&QOyIMIrV?8lE$OI_ z_WC@Nx8$vPTMT0EGxkx<>IIge@@&$G_PxDi6eIp}Jv~P#(JtCrs6Wx7ZI1q&{Ueq1 z9NzP&HkE8TytdEZ-ovYPU%gHF9NBMS{dR@~7|KFW_66&4YDQ*fKbjU&SqRE0EZ27U zNPtFDlAsQTaDTclnt?D!(1zuhGGvShWd`jK>h$~lJxGa$CPOh<%1$Z1`bqyP&ZGX? z6R4kJJ`!>mPT@4psHW&!M&$34en=|nLjAs{)iPAs{vK|)O`NseNPaJo-}lfS(z-=V zq{n@n&Zv@ZQzbbP*KLLobQ>|opl(xCP)N7SScwZ$7-*y8A`(3ISXRm^x^+5%@9OTS z?BTP^Bu;cowlAnN;S2Bu;3MNcod?NSPGMmhi;5k#5o1C7?w7rBDUExjDV&Lu_D^1`^6eVmzR3?Pv#YKZ*V@}_;Y}0(dOi3RKTzd|HGa6l zU#aq=8bA8a=IV_J?Z!lfe}B{Z{w~~)(pfW+MxLZSn0%5bKLppjV4`I5Of9_;{~c)i z9^VluqW<^_7LZ{i%H;FQurk!XgmgrLA?YZ_h?dW+^<;EZx`8yrBpVJ~14z*(T8uPg zG1E}NmbZy4NuAs&)Gn-BM0jR7UVgN#LN=dDcB8~0MN5-@(lR}N1heA zy#4tlh+GaVnFHUMn&urMPew#z(ifb0yXcziq^P>KCS}}!%RBR&=LPhyoj^ao{slhI zCr=zOKaiq(e;gECdE1dbx4cX|{-PjzJf#&053X&K9nZ4n4$8HGq3k;IO|_f(8X78ep31qIEj%*3D}l#=eEOd=dk;mb7;LKk{aJQ0f@ zoDTR6x%;ZPS`9aZK$)k}$XqdS)WAU`MG@3|poD@WP%%#9I^>bEks!je@N}%V7cOet zWnW=wun|@>0oK76HU|a{fa8&dekRzC`B2z^xxj#N&Nit^wn-^2lhh;g@EhMIz!!s$ zw2j(e=Z$Uf`jhEj)TST5A(IR(lQik~@p3aENFI{O1DD<8N6wH*i6fIrMkaBs+YK!M zTAO4%g#|UVhp4)vuA;BOz+f2>4kkOOGbu@D;tAbhq9Yb{D~TQ536T?uCU6QSNZ{xe zhy&f3!?F}g3(6#*Ni-CN(|G7%uaFUemmX7yro`5JO>QHHov6Vt4yFSkRSJcpfdTm< ztek>R`4}t#2z-X(8fRKdt@F1R-kW=`;yewA$+Z#|;%mx@=nnqDryHP^QM ziq?3pcw@_Ta#ksgn}B>jyHr>#RQU77t00N@%)V0gE%hw+JjkvVRtlAti(gu_me(}z z>l$}?7`&}zVXl4ooppL;VqMVAzN$42!U<4z_-0!cJaZnX zEZi&ZN~2%n{02wE$(eUc@6Klz3Uh^u`&{w*miMKqw@dSOtp(SEo8GR9cc3`-*wHxC zP->W!HAhQvc+1;XjzCp0yc{gYH1AJg-Q}6>)?95>{D9r*27Xqx^H zW2LRTE{b;VupasYx_oTl9LUc^-D{m!}2^(MpoE56vtZ55usfWjst zFbPIr1xJQWu=MFjIN({6smVGM&a5lpVzJ$sac652wOLQX1MjYcS8ykMLT$n?coKDj zHxUqgiJ;(5gdB_*7TJZ<_H^av!r5}!E(vu);E}_8d&HX0MB{6Y5jflvD%ZzlVHz$gizKDUS#nJjbJvICaTsVwNzSJ=F{$MhalqG2#(N0143Ifq6S;hj z%PH~#iQrOK2Fv;*yISmkzPtFti z#XK(wus||#GtLdaKQ%l(J9DAm60C~C6m9wCRp>UG&WPi@G}lFvayADq0?}kL6Ysfr zkwmV5Xcw2oOeRXM-n=q80WzG?6cEVS`wzI_|cCOIcc7vJIL#B7$klaZ5)R89g( z<`qShw9G0|^SPX?Xex=nKQh&iV(urMWVk57$!Jf-F#Qr%m`A95aXucR)x_ zjEu{?V9$f69!X1DfApm(BYCL5)El){OW;i=>O8N&^0H!9R#xeNy)jTZds>twRUpvX zBVA1`E9qBWO-pmu)~NJskOXL!XNkZuJPV>X5|ClNlc$=w-wSybRO zWcxsfERpgl!f5toO3H99NsWNEfE3{l-VyD)gcWcXCD7<4k=_Hkq;Bg(@+9&sl<(U|Wk0@Fm)VKn9? zJata&&GvAb%-u=#TsFd1WnS#dUgGj_4h|zsPhJ#Et%Iv?v{%<~rv7khD$R5-Oq!`< zn5PV2;8PFG_7q|IX&5s=DaHuJhp!KD@yTIsW_D`?@ z*TC=jpTLNljLt0BEfr6JJppg)wXbM(&{mJ`5o*?4(9f$ojvdA9=hG}4XBeebgZ^d` zoO<>f1LhOYBMa|YMrRjnt7Lx$PYvZj(*o|MDlJA}9={f_t98~M zPhxd;-LHj8EsXBBZOnR|Uh_XNS53uSHR}PbzBE^z?$86eQ+GYGB(NUTgELIjlmA=T z?(wirjm;~Ny?o{18Un~6FifR$V1Ohr6!-=(QE{uwfPd6XymXDhbeQ&n4ssM4Oh|jJ z0RRZ5M5WuZl9sMl-~qk2fQA~Pi3OSFZe}O1O%9N0(^&rZ?sxfi+sI8fVIlqRyXAB# zhM!k%%@mdYq@0^Zdv3qJs$vkEsByl_@ai4VKy=N!fP9_!G~VM z1xN}sAqbxuuABlmY`6q%H76P_O$HNhxJ{&1@J=guKNVDyXvlD_aEKDl6~F_-m+9dw z85v$En+IihG(r3KZfqb0jbla$YLrsP!)rsvSw8a}Kw*K4@PxOqO4GT1SLoiIE& z2~5{-xK~UF^DLKY;D}HhczaMF73sx|3-fYTyddX9DVgDyF9_n&1uCBIth^v!P?ep>^}*#>uVsHr^=waWT+v z|JJ=*TlM#D7emzQRz^)9~2ScHn8-sz30YEDZg+w)u%Cupd6Ry|$g%J+sU2 z&g|9f4eece#;|{MqIjZvzb*Exh7Dc(iebGy&pgce*r)9e+y5^A6aB~fPor1A486JS z{4#X*ADap{o&?Bt^N!~*&{G(B=4LMRe=0we|GxQ)j?X)O8jE93@0X$T`!nCcu3y_q zZ7=Lb_HKNZISkAcMtHaiUqK=^Kw5ZddI5OXC;%yF_Z2iMM**m}1VBmy_GrYc3A1L%q9NrZ zmVrz#03p2<5YlTwhyumoIX! zauMop<44v-ZvQ`jk(DmkPSRp4P&AnpWmj*&3;59JKQmc&on;vZuo<0sWc3EiU$gdC z?^#sew1*WY8rU$U^=OTN)+4tyod|07OD5W~`;;J;yh`*rXO*-DG+Nd8Ikc;m_eq`0 z9IqwkMfHN!W??aaZ-t?YvPLu73u+~}B(G{nYQYWxgZ=taCX$?rslW13J&D2Ld`2!vlCe}w7WiW*;)V$msuz=ReKzUd1Mt?X4S2;-+^_md)Ixs zgO>c1t0I-CO2DGs?aX=rm@MK%RkB)x(25vQm8{kfw4%JLlGPfn*!OtFzH9Zsar__l z>d}L*Xz}WySG4$a?<-pTy8rl=Pka>;J@#wMqp8&}lF^yUQw z;HsoMe-yHNuQzK5%rxoEb}e07%X+JZP;Iug=&ib6udCW=8|WT_cvaHd?2%CWqT&31 zfcnvpS2;)}E6cos0SHy45wQ~X2oSlVu5AoE!{x6Ia9M~oG&KrwF$zR9MrKF6jthxF z@G3*p^c*<%5Cp?8i2&5xhei>hZfSfx4klI1siL^ZE%gDC!NOqZ0n?4ZelM>hvf4$; zWE#pPtJp>3R-75P>Y~%&5XYo4SaCpmv$?@oEY?L*h?cOod3=q&!ZDi`A$~(4PpAmR znbAKCn_1ICh$N#x*HvOiX9RKw!cs?IJxbRIj@t1~uDPM0rUuu1up`s8!#Y%0`UGViq%`p9e?kIlT_dEk_{j1r(PpIuCNbRLg{70C z(grUIT1M)9Du{oE3m3*2?25?=ZU1q_WPs>#8Ls4fURqQT3Jr%O8NO6rf|#)^@fpLZ zrZaiyfl5#9H05?sA?H@P=m~>eHCiC%Dya}J%R;QHe9@}vOP9{@S7^NvE-RjSvkY(P zR+-dv0QZACGn_cM;es6eV<%o?cmZ0QDs$c{a!WmDQ+Ev~hS3d2K4*BXIYq&6!(ixh zQx7}EL02gD;u9TUl*ln)U{;xArtzpo(4$pu!gqfLKXnf%Ao#I;Ohe;l>qhHV{YE?b zvBCTA-+Lc?*t)>|>-Vk~hJRh#^3VPgPwMKyM{VvXcAVRKcU##x_37IW-`+j-@cd54 z?t9>}Hui(d8u-R9J>L7Z_iF##3obA?rH#>##tPT&o-H;+%)&Dt6KK01zZc)?+#1-g zJ69Mk`dc3RPagPBZjEgv_WhB*1}P_+HCr&M$nQ`}TBL<~ZA4Y(M?O*vGN`6VaWPUFT1HKlbhD zdzU|Z>x=WBpZ}}_ee6jVo$M(VyzCD0!-0@*#QW4jFp;Y)?lVr}bw?U`cJ*ET0NCif%#4<-+rzH!(4m79Upoj7c}bl_|*jNBcC ziu%(SJMeDxZ8yI>x6A#f*FU;`cl2k@7RouJj)@V_@>-4)KHdNbai%e+Ni;R zRIMm2r4?B+YUV_(0x=K{=mj^gI_mL5h6;je`wP#z;0)w5;vXs^G@*Z^{taYs)-3xQ zM=R@m-pjDQUodt5#?1Z;a~sst7c97op~ga8F<4*lgT@!UMSoqv15HgW`%Ji4-%tnw z*KcfjaA$L2V_{R;koM1B-m5umdH1kkun_ureba;K&BR7x^Y+H={kQsdlZVX%hxLC@ k2>!O|Ewk1%9bokv~V?2$CO1Xa)hK+s~}1(tAQmYaA3n>(HzcqOX)=3 zk=;8gBFmtD7^qS75xy8BDXhGrLVfs|ypn%E6r4iDHUUyJAOC1bL4mwFd&d(ACaKju z-0sZI?Ck99%LwB>1E7%kanG zk7GF(HwNzV*;pl%8m#`~&5C7VmzcV8yPV5mqU4H1S-ZJn+IdAMMaM}K4M43>NE2+7 zb*z}SV&Qd7wTd_+%hz$Kpy}9A42(egJQ zR8dN7c}9NB?6O7&SC^I-)z#OR)GJG`-dLvG^SW;^zwcNN9gYItKp4UOAa*y|uo1e8 zvV0|+;;Gy{iS{R0(luDeN`YZmQ{tU&X+)gL-= zcS|^|i%AYFtT5ra3+KOG>7L&^@VqA}tz@?6701=qfk)F_bIaRHxp?)BIWUZ#$42Uy zQFU@Z#Ovr<0jz)V#?svC3;m1SVkxhh25bUQI`1h&T+q-k!2tr%9@jhAT;Fh=fRhx< zu4%6;o#}wpx7@r^whc@Q6~^(a8&?)Po)~bg!|Kw!QX)mz#Zf%eqv?^J+hjo(hX1ex zeFN3F8d8tRwGd#zb)#AsuucdeRO10-E4@||0OOA_5ilXZ+#s;gNrQE-v-I6aO)$7B z?D?)#6FX4U9S9_E4hK+o1T6(4#Mboed%8;xpjzZ_;k}N=pe-aLzdcpD)FXE+Sk)c- zJ+2QNeI4B}_#FZ$WCP4*jN^(0;3IFH}Wk8xqh$Rc4vQ%P9-hK_Bj< zh=r=uNdpV1GgR<`p2qC5rTRJ6p@WvTt+ENXn8&t51rVj?X>7wZ3~Z~8S;aKrWlswm zrsiUQ3fPFgL3zDwXmn676M}75^%bQ8wn}b=^04e@l-s6a8K4&~Q=X|y1s{CQn@(wQK2+Jvi1s7oj0Ab-ko?#tx zV8WuA4m&wDuX-kMRR4H1(K`oXF>zQy!k!hseYZ5TQMB+(v4m~CpuIB#N1X9q+_Y8P zEY3JY_fhWwhPPbJV`AgNjMIIk;YFpAs+zXxs%rIA_YyO~B=cRz1z)&>z8)FB$NyPw zO@O{R8l?w^}$O5?3$ z`r(EBeqm& znanVT>H74}7BGx3hA1$coqQ0xA7fIM82wWH(oPQOXBm9}=;zKqxOSfj@agG#dRGMc zIiP<5YQ)i*WFueCdulKNnR;e-6d2Ah0kLmo6d(UGo_O!|ohu(-{pjk)D<7@wy!>$V z6aKU4r_s-bJ{@Yl^y@a_UgIWPljk2;_pO7;+5O4c=DEucGoPP6SiHW!c>Q4U*8bwH z<|}UjW0p&{k_uDlAer7zrkfL)wh%h?{I@7HFxw8pZiirew3ZUs9od?NRh7yrlx0S_ zz~H#5{;sSQI+Td28bw`I2@CSSBJvyz8X(&6ZoFzYjVJU`~8 zR3sRxGh1g|L6ZRR-eZ=68i@Ib-)ih_4#Dfj*Mz;!iZ@Y;tNwAZP$)uugcVI;PWZj5 zlofqzH(50k)QpPnvpwZhZw=X3k)EUkWwDduJd>ov{~enS)7_6X&yMu`(>EY9ux1&4 zM}h7RIvPXq(UvrD@AvPo?Jl*XQ_YbJhe9|mwozCRJ=3S+ggq`aRCEZd5@^CpU>U=7 zplZKehJujBdJ&!>G|W;>1lwCCDcUp)s(DGszig%jVGmN*3%-+n={wHL<%f()+uJ))Kr?ZAds4=|4^w?+5VU+?Vni5Ms#Bm>X0V>a|u$XYTEX^cRt6Y zW^~eC%kTI5d!P4tpZEE=KRO&X1nJNJG>v>xh0xz{Lr+w$vKA!}nnNraL@dD?Lc}1U zmWDxtS{lQqL6X2@#*jH|8MK70gH{3s(O?N%%9^ecgEp4rh!K|#yk4^AXY^a1(Qkc5 zf5|iYZN>Zj6l=d~8l?Vbx~SrK?EjB02hio1G`O9yPYnp!27)23hv5Ts-?{G7t#mIJ zVCddR7k!EocrGN;UnM@EJG-0q)5p7dd+G6@G)&`pEO#Zw@KSJ!k%AH4W3vqm2Sqv@ zVPheVjtY^nAj^rgG|bVyh`?!!M{|(=h@g%f(kvMNG>IJ+JcRT(1iDSpQ`o4Par_C>)=o{#Iv$yNmiH?D;y(&TDTre;!iS!=& zY*%Y<3oZJX5GP`wW_mXp8RvDB!li(vf&Nzv+tHXHaJ)n_QJ{qJ4Y$+B^y71QbAZ7t z+KW%c(?X0d;MiD@p>-a6G%f+7fGGm!14N#s{W0E$naG4Z^c%d7%TK2nmKA9pI6TJD zk}tsVuk7=L!;Eb-$g}P3;R%16f4{%gGZc#WM!o8RX)zX!Mg)n*Jnm}lZQ7%9V=q0< z?Fxlx5$?}7OiK}44i#Uey1YSM&eFZT9qQUJGZ75O!f+xX5(^ARW0HEuaJQV$z)e6d zX_p!k;hq_KyT}E?z$1>`Uce?1lOjb%IbNiPCTVTmUfU1@RB)^Y08qIBIwl4|5Or=f z_#)vbNGGdtGbcSB+avtE``iYlGR8+4I62oGZt_YI?~uP~zfzX#jB%~u7H_Pj*=Jc- z2M8-8BmCFi2bFUTT_dg;t{JbHu94Twb0~#UM9Ppdrc5a^W%gU*7TF@@7mp#KWW6O@ z>R_ZFNf=E{BPPAJUa~}hxNW!vd*Q?yr93SoR;{lvhy4wHgJjnGMpV_Qx3uyWyjh%B z({r#4DdKx$k11pTtLAGs}4 zU7fLDh*cki51tqYGNlVb!E4flK$TFeTLSPhc$y7@%N&w&H-ykgj=wOPPX(hz?l@&nmFl=h*@z=>|VLJTNo%~4P+J`Ub*a7>|?5ZEKHILrtft5}Epp$H>^VlLJC{Q{#J z3dN!hv@|QEFT^o|QgS*Nr-YWiJ=r)M33H8+D98Ik z%y=WqjWzOP{%C7MI5HAx6a}Bw^oGG<3P(7vrkNTgFjygQr*du40|zO=hRKUhGqv3t zW=1*HQmGxDLn^)E5l?>3iYP5GI{yJP7n}{w6Y~d_U3+ey`Q*aQ z3yW9o>{@Q<{PO6E<8&5LC9h@%E-jsXYw5zp`!%(5@65h4KfY*L+S&H`wl7}(?B(T} zuH@^_9&pv%bM4Hyb}m>JYnFB${JiOlgP$E-c6B9>KcpPf$5Y2=j;~Pl535~s7iTa2 zy8JhFAJ;83Ed-Xg?Z4Z;Qr-6$86B-N#Qn1JIrFS}=1RIE9k@|He`Vpo$3IxAYtGcP zES_H}JF-4ro__1b=J_EQ4;2}on6F(p@TanttOa17*pQ=ix;xc9b0FQ39-2RJBec-D zSaY*yX=_`i{@@DL{!Ja6{u|7KuW|OR=l~Y}&TwpBH!}VKk=-R0w^?!IZp5ov7saYy zlwuqSN`ezsDCPk$BjA`})8H-p5 zTG)vPog402Xn@|Wc=B%{`&R-zuv1B9rhMjj`mOoE!qk$jJ#lo^V!IyBST^5x)~EY3 z&TUK9Z4XJuwM*|_N;joB8S@xX2 z?S&_yP@1#tRedgCZ6rDh##dA*wU(Qym<3LX3A_)qX+B|p2y3{t--895Loyo453N^J z>qy4UGEsCC*2gVzYrI4*VF|2h<2KnQn@4v8ryVB|Ymm(!g6}2UK19DV>7K}s2-#Q+ zjWuCCAGZS~MLdBxCFQ9fDPo~z68xLIH7aJGW$HG0D>ozV2%>nYWYdfd#Ho zYt;q(EMoIUidfDO;N8X~SIgx(&*CoGBvf@kXXG=C^i)x ztPBh}v^snk0A`2^IpGy!Ls;eF3s5XMc8j$@neOUrEH$fDmJ?akJ;B;tI16iDgTIJl z<^;-aMRhNHMBSh=rs~AW8FIC9>%GbsGnFsC+mq->TuJU+wU*A5r)-IiRXa69zTduN zI+5NoADa&>HFVu&?)ER$oLDxUNVcz1RnvW`zI0QD+LGv6byQAYPF+rPKXg<~Urb$0 zZ<;@IoBM>n$=^M<;^;|qXRWBpmFT(esF-Q`$@#~KAm1R8#QjYT3!`_o-Q@3tzHIs| zvb?D;NzI&j-{wZ)BApQ=`kuT*`MrMt=sYoX1+Q-d}C|!vn5k}DBp1h)tB)kR zVP?&a`NNAZE$ujvakbrZb!J?hcbk`8-O1ytRK?8UWop~}xrIRHrF|J{-wL(u>n%W2 zwi1q!ZAMf@qF3<1ohXi+>BS#HIJUTL-47EBV_Z4^MFOHkEca>xjZ1+2GRme~IU5K1 zjF{@~XTlJSi@GYy%LP3YlW6T@i%7?UywuW6Yg9b-1&;Oy1yS<24Fa}$sxZ0D1+V5D z?k#vVx%2p=Z6a0+nviOuxoM+ldbRz$qF=ZGbAJeb5qoXmIG3S{?eh)G<*y`8-nVZ~ zch8?%+;r#sU3u9)uw)ugdFM6>I8a^B9${X%OBIXP%R&U>6|j02+>j~dy1VGX&V3BR zAH}NIH0N=<8Xas12RvG5?qjK{V>C8^B}S@0DD1&C{KV8F+(asA>UqDl9@4(29ueM# zajf>ldys)>6U5hsQo^)GBf{|)_lrBHcg!m(!FOjKu&5p6`pDpfw+lj_MD zq5h!UHCI1dpCzFV!dF>4w|91L)`DABH3x3JLzEC_ mi99TBaZ|J#W}0tsOP35|*4G>M@=Lc0M%my)CelKnxjBjXu6HWNF{jFXa` zsHWR)9a;&8trWBEilFTZs#YqMRw|F|(?0BrUGFN^h=-{Avb>p1pDJy4&z85gh)I_0!2n zCx`u`{p#2LJ~GV4F9ukW4+dDqV+OrADNHfs#OP3WfQ^J>42^sO??~pg873&e`^77= z$VE0Dywoi#(&Wp-qoXA^51YHm(}Cb6kVNC$=#-Z{)jb`E#yL+#N^waxRPw#Q&%42c z$%$jfPLpilN<1NW%;t%y2+RO8!v=y3PtpuCLxq{?X&;&|pP)`C6`Kvs_^DF=6pVsW zClt+OfQ3=R00TmTV@MbvOe8N4EQS__y+Y28jEu_Gd4{<}&H8;aSIAH#2GcUnsAdnF zNoe1u&rVqqNq4fnC3P=ZCMjhFyj4t*DI})P5@eAD2iPzto9goI%ux+Iby@L2sI~*fVvGgU$_+xS9@Q z3p&B1+7$`a>*LskvlwgN z@kx1gw4x@u1sBZ$*AHW@)nj?47tjiCVdY!&fxp|2)<24-ZBr99>c}c%@@Nt-g4l^T zOK$*6QpQRd1Qqn;50}7oAK)G!>zBX1p%xsSgtliV`XuOB#g37Cmt&poIC0F&7V5E$U0G?7EfG|r_)xB}8 z&>LD$z8j_}i50DweaCfbk@{fr)`goF^0uzziGrzd)wCmL+VR2uTSsml$(uaMqiY7! z+pk@H?Kku3@i$&CxEeFNmzq=hVk>5?OO6(IVrJ|0rp2arcBBK>S~FwWmIHoKeQVKaH_(9J_DzsNWHR+|sznh)N0<(mgmBZY?M z44Z4%mpW0jD?XVphwFN1G4#$@`b&l;HqD4S!&5#%)8oBM~j*#Wb*co zH8i?g-kaW)+58S)&gDz#*DS8&=sL_ACJ}rnSdF%2EH*3sg6=kHq1X7(F%0dr0 z8(%{`^GpU zQji}=;gJ(P5EeaZf``W-X=fwM9ETF5q?7AZNykqGW*82wu%wf_mCV6-Oo)UN@dWSD zaHPxu081Q{^fPj4$s?11P~1+$YoO3qDuHTTN{Ca6^IQ+``r*fa1CJ!OX0=~`dGY1< zhL=w*oyuFu{>MrOuN$6>O zlZ0MCSx#=XAWPJV`Y=HoJ~4ibb|ti0aD)56fDJDhI7W#bFOo8H&%>XD4@>x%q(4PX1=tYRj6Uh-DCC%= zode$_3&dE-XPQTmDrfp~Fm0nMe5!oD;JL#fi7FxfPw;?fs^7jmxHNcmBsrA4oZ4Hk zSX12Mf%N#bfvnb_*$un@Saw%`w#lE@`cnf1qy4&N(UKlYU(Os|np+`OPUMYU$)SSH ze(my&w)B~FAbsZduGFs=e+3EOV!bwg15fwDXFq(rkh3%=k3XzyT<%)x%FM0^nb^Iy zzt{O!>jramgZIZD;13Su>t0EoDpxr{cu(Uy8xLD5tvQf}$_9>4(=(E7L*2z8;FgiqW9J4CJ%}Dh&;wzuWGVxdO{u)v%Llv~ z*KGpqcNQ)(xe?Gl2|xZn@PI?!UTpWioVGQ2GWB|)%~PVF#4&Eq8rmN=dhYgQ2Sy45 zFFkPI_kN|pn|_KfT7lj4+-jpI*Xa4A5T;Fisd1BQ$o^y`pooACL3(l5rE&G)oYv{Ur zDDU2%94)wauDUyO?#{e>UvhNKx;=9!Yw5UayLTn~+)&msl+_N&DanJ6d5Ezwk3}+2 z6deyzl#X%LykrdRLddu7Auby8wp|^uZd{>6K#olC#7?5`ppNb!e>xe9FGve(l z`|k|i9z?A6d&8fPd^Unu?Zp=Cxvs+AeUDB1@S`}Yv4>G_D;;;dx4l`nFK6#5Y7lF! zCQFA{XoKur7MH}LL4Gx=ud>ae8N5vGPdCWt?ZhqdrkHhg torch.Tensor: - """Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn. - - Args: - packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element - - Returns: - (..., sf_k_groups * 4) float8_e4m3fn - """ - u32 = packed.to(torch.int32) - b0 = (u32 & 0xFF).to(torch.uint8) - b1 = ((u32 >> 8) & 0xFF).to(torch.uint8) - b2 = ((u32 >> 16) & 0xFF).to(torch.uint8) - b3 = ((u32 >> 24) & 0xFF).to(torch.uint8) - interleaved = torch.stack([b0, b1, b2, b3], dim=-1) - return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn) - - -def unpack_e2m1_to_bf16( - packed: torch.Tensor, # (..., K//2) int8 — two E2M1 values per byte - scales: torch.Tensor, # (..., K//16) float8_e4m3fn — UE4M3 block16 scales -) -> torch.Tensor: - """Dequantize packed E2M1 with UE4M3 block16 scales to BF16. - - E2M1 format: sign(1) exponent(2) mantissa(1), bias=2 - Each int8 byte contains 2 E2M1 values: low nibble=element 0, high nibble=element 1. - UE4M3 block scales: one float8_e4m3fn scale per group of 16 consecutive elements. - - Args: - packed: (..., K//2) int8 packed E2M1 - scales: (..., K//16) float8_e4m3fn UE4M3 block16 scales - - Returns: - (..., K) bfloat16 - """ - u8 = packed.view(torch.uint8) - lo = (u8 & 0x0F).to(torch.int32) # lower nibble - hi = (u8 >> 4).to(torch.int32) # upper nibble - - # Interleave: (..., K//2, 2) → (..., K) - unpacked = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1) - - # E2M1 → float32 - sign = (unpacked >> 3).to(torch.float32) * -2.0 + 1.0 - exp_field = (unpacked >> 1) & 0x3 - mant = (unpacked & 0x1).to(torch.float32) - - # E2M1 value = sign * 2^(exp - 2) * (1 + mant * 0.5) - val = sign * (2.0 ** (exp_field.to(torch.float32) - 2.0)) * (1.0 + mant * 0.5) - - # Zero: exp=0 and mant=0 - zero_mask = (exp_field == 0) & ((unpacked & 1) == 0) - val = val * (~zero_mask).to(torch.float32) - - # Apply UE4M3 block16 scales - sf_f32 = scales.to(torch.float32) - sf_expanded = sf_f32.repeat_interleave(16, dim=-1) - - K = unpacked.shape[-1] - sf_expanded = sf_expanded[..., :K] - - return (val * sf_expanded).to(torch.bfloat16) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 3990d409..881ac113 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -26,12 +26,24 @@ avoiding the costly dequantization step. import os import torch -from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 -from nvfp4_megamoe_kernel.tilelang_nvfp4_gemm import ( - nvfp4_blockscaled_gemm, - grouped_gemm_nvfp4_native, - grouped_gemm_nvfp4_packed_sf, -) +def unpack_ue4m3_u32(x_u32): + """Unpack uint32 packed UE4M3 scales to float8_e4m3fn. + + Each uint32 contains 4 UE4M3 values packed in bits [0:8], [8:16], [16:24], [24:32]. + """ + x_u32 = x_u32.contiguous() + M, N = x_u32.shape + out = torch.empty(M, N * 4, dtype=torch.float8_e4m3fn, device=x_u32.device) + # Vectorized unpack: extract 4 bytes from each uint32 + b0 = (x_u32 & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + b1 = ((x_u32 >> 8) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + b2 = ((x_u32 >> 16) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + b3 = ((x_u32 >> 24) & 0xFF).to(torch.int32).to(torch.float8_e4m3fn) + out[:, 0::4] = b0 + out[:, 1::4] = b1 + out[:, 2::4] = b2 + out[:, 3::4] = b3 + return out # CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell) # Primary path: uses CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled @@ -97,21 +109,11 @@ def nvfp4_mega_moe_l1( x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales - # Use CUTLASS native block-scaled GEMM if available - if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE: - output = cutlass_grouped_nvfp4_gemm( - x_fp4, x_sf_fp8, - l1_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - else: - # Fallback to TileLang path - output = grouped_gemm_nvfp4_native( - x_fp4, x_sf_fp8, - l1_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - + output = cutlass_grouped_nvfp4_gemm( + x_fp4, x_sf_fp8, + l1_weights, w_sf_fp8, + topk_ids, topk_weights, + ) return output # (num_tokens, 6144) bfloat16 @@ -141,21 +143,11 @@ def nvfp4_mega_moe_l2( x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales - # Use CUTLASS native block-scaled GEMM if available - if MEGA_MOE_USE_CUTLASS and _CUTLASS_AVAILABLE: - output = cutlass_grouped_nvfp4_gemm( - x_fp4, x_sf_fp8, - l2_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - else: - # Fallback to TileLang path - output = grouped_gemm_nvfp4_native( - x_fp4, x_sf_fp8, - l2_weights, w_sf_fp8, - topk_ids, topk_weights, - ) - + output = cutlass_grouped_nvfp4_gemm( + x_fp4, x_sf_fp8, + l2_weights, w_sf_fp8, + topk_ids, topk_weights, + ) return output # (num_tokens, 7168) bfloat16 diff --git a/src/nvfp4_megamoe_kernel/tilelang_kernels.py b/src/nvfp4_megamoe_kernel/tilelang_kernels.py deleted file mode 100644 index 1ab67e19..00000000 --- a/src/nvfp4_megamoe_kernel/tilelang_kernels.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization. - -This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer: -- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales -- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales - -Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang. -This is correct and functional. Once TileLang adds native tcgen05.mma -kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to -native FP4 block-scaled MMA for maximum throughput. - -The per-expert GEMM uses a "segmented" approach: sort tokens by expert, -batched GEMM per expert using TileLang-compiled BF16 kernels. -""" - -import torch -import tilelang -import tilelang.language as T - -from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 - -# --------------------------------------------------------------------------- -# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05) -# --------------------------------------------------------------------------- - -_kernel_cache = {} - - -def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3): - """Build and cache a TileLang BF16 GEMM kernel for the given dimensions.""" - key = (M, N, K, block_M, block_N, block_K, num_stages) - if key in _kernel_cache: - return _kernel_cache[key] - - @tilelang.jit(out_idx=[2]) - def bf16_gemm( - A: T.Tensor((M, K), T.bfloat16), - B: T.Tensor((K, N), T.bfloat16), - C: T.Tensor((M, N), T.bfloat16), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) - B_shared = T.alloc_shared((block_K, block_N), T.bfloat16) - C_local = T.alloc_fragment((block_M, block_N), T.float32) - - T.clear(C_local) - - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - _kernel_cache[key] = bf16_gemm - return bf16_gemm - - -# --------------------------------------------------------------------------- -# Grouped expert GEMM with FP4 dequantization -# --------------------------------------------------------------------------- - -def grouped_gemm_fp4( - x_bf16: torch.Tensor, # (total_tokens, K_dim) bfloat16 - weights_fp4: torch.Tensor, # (E, N, K//2) int8 packed E2M1 - scales_ue4m3: torch.Tensor, # (E, N, K//16) float8_e4m3fn - topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32 - topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32 -) -> torch.Tensor: - """Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM. - - Strategy: - 1. Sort tokens by expert assignment - 2. For each expert, dequantize its weight to BF16 (cached) - 3. Run batched BF16 GEMM using TileLang-compiled kernels - 4. Scatter results back with routing weights - """ - num_tokens, K_dim = x_bf16.shape - E, N, K_half = weights_fp4.shape - K = K_half * 2 - assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}" - top_k = topk_ids.shape[1] - device = x_bf16.device - - output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=device) - - # Pre-compute expert weight dequantization (cache for repeated use) - # For 32 experts, this is manageable - w_bf16_cache = {} - for e in range(E): - w_bf16_cache[e] = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e]) # (N, K) - - # Process per expert - for e in range(E): - # Find all (token, k_idx) pairs for this expert - mask = (topk_ids == e) # (num_tokens, top_k) - if not mask.any(): - continue - - w_bf16 = w_bf16_cache[e] # (N, K) - - # Collect tokens for this expert across all top-k slots - for k_idx in range(top_k): - token_mask = mask[:, k_idx] - if not token_mask.any(): - continue - token_indices = token_mask.nonzero(as_tuple=True)[0] - - # Gather activations - x_sub = x_bf16[token_indices] # (n, K) - - # BF16 GEMM: (n, K) @ (N, K).T → (n, N) - result = torch.nn.functional.linear(x_sub, w_bf16) - - # Weighted scatter-add - weights = topk_weights[token_indices, k_idx].unsqueeze(-1) - output[token_indices] += result * weights - - return output - - -# --------------------------------------------------------------------------- -# Convenience: grouped GEMM with uint32 packed scales -# --------------------------------------------------------------------------- - -def grouped_gemm_fp4_packed_sf( - x_bf16: torch.Tensor, - weights_fp4: torch.Tensor, # (E, N, K//2) int8 - scales_packed: torch.Tensor, # (E, N, sf_k_groups) uint32 packed UE4M3 - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, -) -> torch.Tensor: - """Same as grouped_gemm_fp4 but unpacks uint32 packed UE4M3 scales first.""" - scales_fp8 = unpack_ue4m3_u32(scales_packed) - return grouped_gemm_fp4(x_bf16, weights_fp4, scales_fp8, topk_ids, topk_weights) diff --git a/src/nvfp4_megamoe_kernel/tilelang_nvfp4_gemm.py b/src/nvfp4_megamoe_kernel/tilelang_nvfp4_gemm.py deleted file mode 100644 index 0a145d60..00000000 --- a/src/nvfp4_megamoe_kernel/tilelang_nvfp4_gemm.py +++ /dev/null @@ -1,599 +0,0 @@ -""" -Native NVFP4 Block-Scaled GEMM using tcgen05.mma kind::mxf8f6f4.block_scale. - -This module provides the native NVFP4 tensor core GEMM for Blackwell (SM100). -It uses the mxf8f6f4.block_scale PTX instruction which natively performs -E2M1 × E2M1 multiplication with UE4M3 block-16 scaling in tensor cores. - -Architecture: - - A: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA - - B: E2M1 packed (int8, 2 values per byte) in global → SMEM via TMA - - SFA: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld - - SFB: UE4M3 (float8_e4m3fn) in global → SMEM via TMA → TMEM via tcgen05.ld - - C: accumulated in TMEM, stored to global memory - -The key PTX instruction: - tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [tmem_c], - desc_a, desc_b, idescE, [tsfa_addr], [tsfb_addr], pred; - -This is the native hardware path for NVFP4 block-scaled MMA on Blackwell. -No dequantization is performed — the hardware does E2M1×E2M1 with UE4M3 -block scaling natively in the tensor cores. - -Implementation Strategy: - We compile the CUDA kernel at runtime using torch.utils.cpp_extension.load. - The CUDA kernel uses inline PTX for the mxf8f6f4.block_scale instruction - and TMA for efficient global→SMEM transfers. - - For the MoE (Mixture of Experts) use case, we support grouped GEMM where - each expert has its own weight matrix and scale factors, and tokens are - routed to specific experts via top-k indices. -""" - -import os -import time -import torch -import tempfile -import subprocess -import hashlib -from typing import Optional - -# DeepSeek-V4-Pro dimensions -HIDDEN = 7168 -INTERMEDIATE = 3072 -NUM_EXPERTS = 256 -NUM_RANKS = 8 -NUM_TOPK = 6 - -# Block sizes for the GEMM tiling -BLOCK_M = 128 -BLOCK_N = 128 -BLOCK_K = 64 # For f8f6f4, atom_k=32 elements = 16 bytes packed; we use 64 for double buffering - -MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0")) - -# --------------------------------------------------------------------------- -# CUDA kernel source -# --------------------------------------------------------------------------- - -NVFP4_BLOCKSCALED_GEMM_CUDA = r""" -#include -#include -#include -#include -#include - -// PTX instructions for Blackwell tcgen05 MMA with block scaling -// We use inline PTX for mxf8f6f4.block_scale which is not yet in cuda.h - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 -#define TCGEN05_BLOCKSCALED_ENABLED 1 -#else -#define TCGEN05_BLOCKSCALED_ENABLED 0 -#endif - -// Helper: initialize tcgen05 SMEM descriptor (64-bit) -// Matches the layout from TileLang's common.h initialize_tcgen05_descriptor -__device__ __forceinline__ -uint64_t make_tcgen05_descriptor( - uint32_t smem_addr, // shared memory base address - uint32_t leading_bytes, // bytes between consecutive rows/cols in the leading dim - uint32_t stride_bytes, // bytes between tiles in the stride dim - uint32_t base_offset, // offset within the 256B swizzle block (0-7) - uint32_t leading_abs, // 1 if leading offset is absolute, 0 if relative - uint32_t swizzle_mode // 0=none, 1=32B, 2=64B, 3=128B, 4=128B_base32B -) { - uint32_t lo = (smem_addr >> 4) | (leading_bytes << 16); - uint32_t hi = stride_bytes | (1u << 14) // version = 1 - | ((base_offset & 0x7) << 17) - | (leading_abs << 20) - | ((swizzle_mode & 0x7) << 29); - return (static_cast(hi) << 32) | lo; -} - -// ============================================================================ -// Native NVFP4 Block-Scaled GEMM Kernel -// ============================================================================ -// -// Computes C = A @ B where: -// A is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFA) -// B is E2M1 packed (int8, 2 vals/byte) with UE4M3 block-16 scales (SFB) -// C is float32 (or bfloat16) -// -// Uses tcgen05.mma.kind::mxf8f6f4.block_scale PTX instruction. -// This performs E2M1 × E2M1 with UE4M3 block scaling natively in tensor cores. -// -// For MoE: each CTA handles one (expert, m_tile, n_tile) work item. -// ============================================================================ - -template -__global__ void __launch_bounds__(128) -nvfp4_blockscaled_gemm_kernel( - // A: (M, K_packed) int8 — K_packed = K/2 (2 E2M1 per byte) - const int8_t* __restrict__ A, - // SFA: (M, K_sf) float8_e4m3fn — K_sf = K/16 (one scale per 16 elements) - const __nv_fp8_e4m3* __restrict__ SFA, - // B: (N, K_packed) int8 — K-major layout - const int8_t* __restrict__ B, - // SFB: (N, K_sf) float8_e4m3fn - const __nv_fp8_e4m3* __restrict__ SFB, - // C: (M, N) float32 output - float* __restrict__ C, - // Dimensions - int M, int N, int K, - // Strides - int64_t stride_a_m, int64_t stride_a_k, // A row/col strides in elements - int64_t stride_sfa_m, int64_t stride_sfa_k, // SFA strides - int64_t stride_b_n, int64_t stride_b_k, // B row/col strides in elements - int64_t stride_sfb_n, int64_t stride_sfb_k, // SFB strides - int64_t stride_c_m, int64_t stride_c_n // C strides -) { - // For SM100+, we would use tcgen05.mma.kind::mxf8f6f4.block_scale - // However, the full implementation requires: - // 1. TMA descriptors for global→SMEM async copies - // 2. tcgen05.ld for SMEM→TMEM scale factor transfer - // 3. TMEM allocation for accumulators and scale factors - // 4. tcgen05.mma.kind::mxf8f6f4.block_scale PTX - // 5. TMEM→global store for results - // - // The TMA descriptor setup requires CUDA runtime APIs that are only - // available in CUDA 13.0+ driver. For now, we implement a fallback - // that does the dequantize+GEMM on tensor cores with BF16 MMA, - // and document the native path for when TMA descriptor APIs are stable. - -#if TCGEN05_BLOCKSCALED_ENABLED && 0 // Disabled until TMA APIs are stable - // Native f8f6f4 block-scaled MMA path - // This code path will be enabled once CUDA 13.0 TMA descriptor APIs - // are available in the PyTorch CUDA extension build system. - - // ... (TMA + tcgen05.mma.kind::mxf8f6f4.block_scale PTX) ... - -#else - // Fallback: Dequantize E2M1+UE4M3 → BF16, then BF16 GEMM - // This uses tcgen05.mma.kind::f16 which is native BF16 tensor core MMA. - // The dequantization is done per-tile in shared memory. - - const int tid = threadIdx.x; - const int bx = blockIdx.x; - const int by = blockIdx.y; - - const int m_start = by * BLOCK_M; - const int n_start = bx * BLOCK_N; - - if (m_start >= M || n_start >= N) return; - - const int K_packed = K / 2; // E2M1 packed: 2 per byte - const int K_sf = K / 16; // UE4M3 block16: 1 scale per 16 elements - const int BLOCK_K_packed = BLOCK_K_ELEMS / 2; - const int BLOCK_K_sf = BLOCK_K_ELEMS / 16; - - // Shared memory for A (dequantized to BF16), B (dequantized to BF16) - extern __shared__ char smem[]; - __nv_bfloat16* sA = reinterpret_cast<__nv_bfloat16*>(smem); - __nv_bfloat16* sB = reinterpret_cast<__nv_bfloat16*>(smem + BLOCK_M * BLOCK_K_ELEMS * sizeof(__nv_bfloat16)); - - // Register fragment for accumulator - float accum[BLOCK_M * BLOCK_N / 128] = {0}; // Per-thread accumulator - - // For simplicity, use wmma-style accumulation - // Each thread in the CTA cooperates on the tile - const int m_valid = min(BLOCK_M, M - m_start); - const int n_valid = min(BLOCK_N, N - n_start); - - // Process K in tiles - for (int k_start = 0; k_start < K; k_start += BLOCK_K_ELEMS) { - const int k_valid = min(BLOCK_K_ELEMS, K - k_start); - const int k_packed_valid = k_valid / 2; - const int k_sf_valid = k_valid / 16; - - // Load and dequantize A tile: (BLOCK_M, BLOCK_K) BF16 - for (int idx = tid; idx < BLOCK_M * BLOCK_K_ELEMS; idx += 128) { - int m_local = idx / BLOCK_K_ELEMS; - int k_local = idx % BLOCK_K_ELEMS; - int m_global = m_start + m_local; - int k_global = k_start + k_local; - - if (m_global < M && k_global < K) { - // Load E2M1 value - int8_t packed = A[m_global * stride_a_k + k_global / 2]; - int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF; - - // E2M1 to float: sign * 2^(exp-2) * (1 + mant*0.5) - float sign = (e2m1_val >> 3) ? -1.0f : 1.0f; - int exp_field = (e2m1_val >> 1) & 0x3; - float mant = (e2m1_val & 1) * 0.5f; - float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant); - if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f; - - // Apply UE4M3 block scale - int sf_idx = k_global / 16; - __nv_fp8_e4m3 sf = SFA[m_global * stride_sfa_k + sf_idx]; - float sf_val = __nv_fp8_e4m3_to_float(sf); - - sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val); - } else { - sA[m_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f); - } - } - - // Load and dequantize B tile: (BLOCK_N, BLOCK_K) BF16 (K-major B) - for (int idx = tid; idx < BLOCK_N * BLOCK_K_ELEMS; idx += 128) { - int n_local = idx / BLOCK_K_ELEMS; - int k_local = idx % BLOCK_K_ELEMS; - int n_global = n_start + n_local; - int k_global = k_start + k_local; - - if (n_global < N && k_global < K) { - int8_t packed = B[n_global * stride_b_k + k_global / 2]; - int e2m1_val = (k_global & 1) ? (packed >> 4) & 0xF : packed & 0xF; - - float sign = (e2m1_val >> 3) ? -1.0f : 1.0f; - int exp_field = (e2m1_val >> 1) & 0x3; - float mant = (e2m1_val & 1) * 0.5f; - float val = sign * powf(2.0f, exp_field - 2.0f) * (1.0f + mant); - if (exp_field == 0 && !(e2m1_val & 1)) val = 0.0f; - - int sf_idx = k_global / 16; - __nv_fp8_e4m3 sf = SFB[n_global * stride_sfb_k + sf_idx]; - float sf_val = __nv_fp8_e4m3_to_float(sf); - - sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(val * sf_val); - } else { - sB[n_local * BLOCK_K_ELEMS + k_local] = __float2bfloat16(0.0f); - } - } - - __syncthreads(); - - // BF16 GEMM accumulation: C += A @ B^T - // Simple per-thread row-major accumulation (not using tensor cores in this fallback) - // In the native path, tcgen05.mma handles this natively - for (int m_local = 0; m_local < m_valid; m_local++) { - for (int n_local = tid; n_local < n_valid; n_local += 128) { - float sum = 0.0f; - for (int k_local = 0; k_local < k_valid; k_local++) { - sum += __bfloat162float(sA[m_local * BLOCK_K_ELEMS + k_local]) - * __bfloat162float(sB[n_local * BLOCK_K_ELEMS + k_local]); - } - // Atomic add to output - int m_global = m_start + m_local; - int n_global = n_start + n_local; - atomicAdd(&C[m_global * stride_c_n + n_global], sum); - } - } - - __syncthreads(); - } -#endif -} - - -// ============================================================================ -// Native NVFP4 Block-Scaled GEMM — Full Pipeline (SM100 native path) -// ============================================================================ -// -// This kernel uses the full native pipeline: -// 1. TMA async copy: E2M1 data → SMEM -// 2. TMA async copy: UE4M3 scales → SMEM -// 3. tcgen05.ld: SMEM scales → TMEM -// 4. tcgen05.mma.kind::mxf8f6f4.block_scale: native block-scaled MMA -// 5. TMEM store: results → global memory -// -// Requires CUDA 13.0+ and SM100 (Blackwell) hardware. -// ============================================================================ - -template -__global__ void __launch_bounds__(128, 1) -nvfp4_blockscaled_gemm_native_kernel( - const int8_t* __restrict__ A_packed, // (M, K/2) E2M1 packed - const uint8_t* __restrict__ SFA_packed, // (M, K/16) UE4M3 scales as uint8 - const int8_t* __restrict__ B_packed, // (N, K/2) E2M1 packed - const uint8_t* __restrict__ SFB_packed, // (N, K/16) UE4M3 scales as uint8 - float* __restrict__ C_out, // (M, N) float32 output - int M, int N, int K_total, - int64_t stride_a, int64_t stride_sfa, - int64_t stride_b, int64_t stride_sfb, - int64_t stride_c -) { - // The native mxf8f6f4.block_scale path - // This requires: - // - TMA tensor map creation (cuTensorMapEncodeTiled) for A, B, SFA, SFB - // - Shared memory with proper swizzle layout for tcgen05 descriptors - // - TMEM allocation for accumulators (C) and scale factors (SFA, SFB) - // - tcgen05.ld for scale factor SMEM→TMEM transfer - // - tcgen05.mma.kind::mxf8f6f4.block_scale for native MMA - // - tcgen05.st for TMEM→global result store - // - // The full implementation requires CUDA 13.0 driver support for: - // - cuTensorMapEncodeTiled with sub-byte types - // - TMEM allocation/deallocation intrinsics - // - tcgen05.ld/st intrinsics - // - // For now, we delegate to the fallback kernel. - // The native path will be enabled in a follow-up when the build - // system supports CUDA 13.0 headers. - - // This should never be called — we use the fallback path - assert(0 && "Native path not yet compiled — use fallback"); -} - - -// ============================================================================ -// PyTorch bindings -// ============================================================================ - -torch::Tensor nvfp4_blockscaled_gemm_forward( - torch::Tensor A_packed, // (M, K/2) int8 — E2M1 packed - torch::Tensor SFA, // (M, K/16) float8_e4m3fn or uint8 — UE4M3 block16 scales - torch::Tensor B_packed, // (N, K/2) int8 — E2M1 packed - torch::Tensor SFB, // (N, K/16) float8_e4m3fn or uint8 — UE4M3 block16 scales - int64_t M, int64_t N, int64_t K -) { - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(A_packed.device()); - auto C = torch::zeros({M, N}, options); - - const int BLOCK_M = 128; - const int BLOCK_N = 128; - const int BLOCK_K = 128; - - dim3 grid((N + BLOCK_N - 1) / BLOCK_N, (M + BLOCK_M - 1) / BLOCK_M); - dim3 block(128); - - // Calculate shared memory size: A + B in BF16 - int smem_size = (BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) * sizeof(__nv_bfloat16); - - auto stream = c10::cuda::getCurrentCUDAStream(); - - nvfp4_blockscaled_gemm_kernel<<>>( - A_packed.data_ptr(), - reinterpret_cast(SFA.data_ptr()), - B_packed.data_ptr(), - reinterpret_cast(SFB.data_ptr()), - C.data_ptr(), - M, N, K, - K / 2, 1, // stride_a_m, stride_a_k - K / 16, 1, // stride_sfa_m, stride_sfa_k - K / 2, 1, // stride_b_n, stride_b_k - K / 16, 1, // stride_sfb_n, stride_sfb_k - N, 1 // stride_c_m, stride_c_n - ); - - return C; -} - -TORCH_LIBRARY(nvfp4_blockscaled, m) { - m.def("gemm_forward", &nvfp4_blockscaled_gemm_forward); -} -""" - -# --------------------------------------------------------------------------- -# Kernel compilation and caching -# --------------------------------------------------------------------------- - -_compiled_ext = None -_ext_lock = None - - -def _get_compiled_extension(): - """Compile and cache the CUDA extension.""" - global _compiled_ext - if _compiled_ext is not None: - return _compiled_ext - - import torch.utils.cpp_extension as cpp_ext - - with tempfile.TemporaryDirectory() as tmpdir: - cu_path = os.path.join(tmpdir, "nvfp4_blockscaled_gemm.cu") - with open(cu_path, "w") as f: - f.write(NVFP4_BLOCKSCALED_GEMM_CUDA) - - ext = cpp_ext.load( - name="nvfp4_blockscaled", - sources=[cu_path], - extra_cuda_cflags=[ - "-gencode=arch=compute_100a,code=sm_100a", - "--expt-relaxed-constexpr", - "-DNVFP4_BLOCKSCALED_ENABLED=1", - ], - extra_cflags=["-O2"], - verbose=MEGA_MOE_DEBUG, - ) - - _compiled_ext = ext - return ext - - -# --------------------------------------------------------------------------- -# Native NVFP4 GEMM API -# --------------------------------------------------------------------------- - -def nvfp4_blockscaled_gemm( - A_packed: torch.Tensor, # (M, K//2) int8 — E2M1 packed, K-major - A_scales: torch.Tensor, # (M, K//16) float8_e4m3fn — UE4M3 block16 scales - B_packed: torch.Tensor, # (N, K//2) int8 — E2M1 packed, K-major - B_scales: torch.Tensor, # (N, K//16) float8_e4m3fn — UE4M3 block16 scales -) -> torch.Tensor: - """Native NVFP4 block-scaled GEMM: C = A @ B^T. - - A is (M, K//2) int8 E2M1 packed with (M, K//16) UE4M3 scales. - B is (N, K//2) int8 E2M1 packed with (N, K//16) UE4M3 scales. - C is (M, N) float32. - - Uses the native mxf8f6f4.block_scale tensor core instruction on Blackwell. - Falls back to dequantize+BF16-GEMM on non-Blackwell hardware. - """ - M = A_packed.shape[0] - K_half = A_packed.shape[1] - K = K_half * 2 - N = B_packed.shape[0] - - assert A_packed.dtype == torch.int8, f"A must be int8, got {A_packed.dtype}" - assert B_packed.dtype == torch.int8, f"B must be int8, got {B_packed.dtype}" - assert A_packed.is_cuda and B_packed.is_cuda, "Tensors must be on CUDA" - - # Try native path - try: - ext = _get_compiled_extension() - # Ensure scales are uint8 view of float8_e4m3fn - if A_scales.dtype == torch.float8_e4m3fn: - A_sf_u8 = A_scales.view(torch.uint8) - else: - A_sf_u8 = A_scales - if B_scales.dtype == torch.float8_e4m3fn: - B_sf_u8 = B_scales.view(torch.uint8) - else: - B_sf_u8 = B_scales - - return ext.gemm_forward(A_packed, A_sf_u8, B_packed, B_sf_u8, M, N, K) - except Exception as e: - if MEGA_MOE_DEBUG: - print(f"[nvfp4_gemm] Native kernel failed, using dequant fallback: {e}") - # Fallback: dequantize and use torch.matmul - from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16 - A_bf16 = unpack_e2m1_to_bf16(A_packed, A_scales) # (M, K) - B_bf16 = unpack_e2m1_to_bf16(B_packed, B_scales) # (N, K) - return torch.matmul(A_bf16.to(torch.float32), B_bf16.to(torch.float32).t()) - - -# --------------------------------------------------------------------------- -# MoE Grouped GEMM with native NVFP4 block-scaled MMA -# --------------------------------------------------------------------------- - -def grouped_gemm_nvfp4_native( - x_packed: torch.Tensor, # (num_tokens, K//2) int8 — E2M1 packed - x_scales: torch.Tensor, # (num_tokens, K//16) UE4M3 scales - weights: torch.Tensor, # (E, N, K//2) int8 — per-expert E2M1 weights - weight_scales: torch.Tensor, # (E, N, K//16) UE4M3 per-expert scales - topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32 - topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32 -) -> torch.Tensor: - """Segmented grouped expert GEMM with native NVFP4 block-scaled MMA. - - For each expert, runs the native NVFP4 GEMM on tokens routed to it. - Results are scattered back with routing weights. - - Args: - x_packed: Packed E2M1 activations (num_tokens, K//2) - x_scales: UE4M3 block16 scales (num_tokens, K//16) - weights: Per-expert E2M1 weights (E, N, K//2) - weight_scales: Per-expert UE4M3 scales (E, N, K//16) - topk_ids: Expert assignments (num_tokens, NUM_TOPK) - topk_weights: Routing weights (num_tokens, NUM_TOPK) - - Returns: - (num_tokens, N) bfloat16 output - """ - num_tokens = x_packed.shape[0] - K_half = x_packed.shape[1] - K = K_half * 2 - E = weights.shape[0] - N = weights.shape[1] - top_k = topk_ids.shape[1] - device = x_packed.device - - output = torch.zeros(num_tokens, N, dtype=torch.float32, device=device) - - # Process per expert - for e in range(E): - mask = (topk_ids == e) # (num_tokens, top_k) - if not mask.any(): - continue - - for k_idx in range(top_k): - token_mask = mask[:, k_idx] - if not token_mask.any(): - continue - token_indices = token_mask.nonzero(as_tuple=True)[0] - - # Gather activations for this expert - x_sub_packed = x_packed[token_indices] # (n, K//2) - x_sub_scales = x_scales[token_indices] # (n, K//16) - w_packed = weights[e] # (N, K//2) - w_scales = weight_scales[e] # (N, K//16) - - # Native NVFP4 GEMM: (n, K) @ (N, K)^T → (n, N) - result = nvfp4_blockscaled_gemm( - x_sub_packed, x_sub_scales, - w_packed, w_scales, - ) # (n, N) float32 - - # Weighted scatter-add - weights_f32 = topk_weights[token_indices, k_idx].unsqueeze(-1) - output[token_indices] += result * weights_f32 - - return output.to(torch.bfloat16) - - -# --------------------------------------------------------------------------- -# Convenience wrappers for uint32 packed scales -# --------------------------------------------------------------------------- - -def grouped_gemm_nvfp4_packed_sf( - x_packed: torch.Tensor, # (num_tokens, K//2) int8 - x_sf_packed: torch.Tensor, # (num_tokens, sf_groups) uint32 packed UE4M3 - weights: torch.Tensor, # (E, N, K//2) int8 - weight_sf: torch.Tensor, # (E, N, sf_groups) uint32 packed UE4M3 - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, -) -> torch.Tensor: - """Grouped GEMM with uint32 packed UE4M3 scales.""" - from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_ue4m3_u32 - - x_sf_fp8 = unpack_ue4m3_u32(x_sf_packed) if x_sf_packed.dtype == torch.uint32 else x_sf_packed - w_sf_fp8 = unpack_ue4m3_u32(weight_sf) if weight_sf.dtype == torch.uint32 else weight_sf - - return grouped_gemm_nvfp4_native( - x_packed, x_sf_fp8, - weights, w_sf_fp8, - topk_ids, topk_weights, - ) - - -# --------------------------------------------------------------------------- -# TileLang-based NVFP4 GEMM (using T.gemm with float8_e4m3) -# --------------------------------------------------------------------------- - -_tilelang_kernel_cache = {} - - -def _make_tilelang_nvfp4_gemm(M, N, K_packed, block_M=128, block_N=128, block_K=64): - """Build a TileLang GEMM kernel using float8_e4m3 (f8f6f4) tensor cores. - - This uses TileLang's T.gemm() with float8_e4m3 dtype, which lowers to - tcgen05.mma.kind::f8f4 on Blackwell. Note: this path does NOT apply - UE4M3 block scaling natively — it does E2M1 × E2M1 without scales. - For proper NVFP4 with block scaling, use nvfp4_blockscaled_gemm(). - - The TileLang path is kept for experimentation and benchmarking. - """ - key = (M, N, K_packed, block_M, block_N, block_K) - if key in _tilelang_kernel_cache: - return _tilelang_kernel_cache[key] - - import tilelang - import tilelang.language as T - - K = K_packed * 2 # Unpacked element count - - @tilelang.jit(out_idx=[2]) - def fp4_gemm( - A: T.Tensor((M, K_packed), "float8_e4m3"), - B: T.Tensor((K_packed, N), "float8_e4m3"), - C: T.Tensor((M, N), T.float32), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), "float8_e4m3") - B_shared = T.alloc_shared((block_K, block_N), "float8_e4m3") - C_local = T.alloc_fragment((block_M, block_N), T.float32) - - T.clear(C_local) - - for k in T.Pipelined(T.ceildiv(K_packed, block_K), num_stages=2): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, num_elems_per_byte=2) - - T.copy(C_local, C[by * block_M, bx * block_N]) - - _tilelang_kernel_cache[key] = fp4_gemm - return fp4_gemm diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 02eefe70..1d61f9bc 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -1,10 +1,10 @@ """ -NVFP4 Weight Transformation for TileLang mega_moe kernel. +NVFP4 Weight Transformation for CUTLASS mega_moe kernel. Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale) -into the TMA-aligned format expected by the block-scaled GEMM kernel: -- Packed FP4 weights (uint8, K-major) -- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout) +into the format expected by the CUTLASS block-scaled GEMM kernel: +- Packed FP4 weights (int8, K-major) +- UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles layout) This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe. @@ -70,13 +70,13 @@ def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor: return interleaved.contiguous() -def transform_nvfp4_weights_for_tilelang( +def transform_nvfp4_weights_for_mega_moe( l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1 l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2 ) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: - """Transform NVFP4 weights for the TileLang block-scaled GEMM. + """Transform NVFP4 weights for the CUTLASS block-scaled GEMM. Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights. @@ -113,7 +113,3 @@ def transform_nvfp4_weights_for_tilelang( l2_weight_out = l2_weight.contiguous() return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out) - - -# Alias for drop-in replacement -transform_nvfp4_weights_for_mega_moe = transform_nvfp4_weights_for_tilelang diff --git a/src/quantize.mojo b/src/quantize.mojo deleted file mode 100644 index 71a7126f..00000000 --- a/src/quantize.mojo +++ /dev/null @@ -1,95 +0,0 @@ -""" -NVFP4 quantization utilities — E2M1 packing and UE4M3 scale handling. -Core math layer for the NVFP4 mega_moe kernel rewrite. -""" - -# E2M1 magnitude lookup table (positive values only) -# Index 0-7 maps to: 0, 0.5, 1, 1.5, 2, 3, 4, 6 -def e2m1_magnitude(index: Int) -> Float64: - if index == 0: return 0.0 - if index == 1: return 0.5 - if index == 2: return 1.0 - if index == 3: return 1.5 - if index == 4: return 2.0 - if index == 5: return 3.0 - if index == 6: return 4.0 - if index == 7: return 6.0 - return 0.0 - - -def quantize_e2m1(value: Float64) -> UInt8: - """Quantize a float64 value to E2M1 (4-bit), returning the 4-bit nibble with sign.""" - var sign = 0 - var abs_val = value - if value < 0.0: - sign = 1 - abs_val = -value - - # Find best E2M1 match - var best_idx = 0 - var best_err = abs_val # error for idx=0 - - for i in range(1, 8): - mag = e2m1_magnitude(i) - err = abs(abs_val - mag) - if err < best_err: - best_err = err - best_idx = i - - return (sign << 3) | best_idx - - -def unpack_e2m1(packed: UInt8, idx: Int) -> Float64: - """Unpack one E2M1 value from a packed byte. - idx=0 -> low nibble, idx=1 -> high nibble. - """ - nibble: UInt8 - if idx == 0: - nibble = packed & 0x0F - else: - nibble = (packed >> 4) & 0x0F # keep sign bit - - sign = (nibble >> 3) & 1 - mag_idx = nibble & 0x07 - magnitude = e2m1_magnitude(Int(mag_idx)) - - if sign: - return -magnitude - return magnitude - - -def dequantize_nvfp4_weight( - packed_weight: UInt8, - block_scale: Float64, - group_offset: Int, -) -> Float64: - """Dequantize a single NVFP4 weight element. - weight = E2M1_magnitude * block_scale - (global_scale is already folded into block_scale) - """ - e2m1_value = unpack_e2m1(packed_weight, group_offset) - return e2m1_value * block_scale - - -def main() raises: - # Test E2M1 quantization round-trip - print("E2M1 quantization test:") - for val in [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]: - packed = quantize_e2m1(val) - unpacked = unpack_e2m1(packed, 0) - print(" ", val, " -> E2M1 -> ", unpacked) - - # Test packed byte (two E2M1 values) - print("\nPacked byte test:") - lo = 1.5 - hi = -3.0 - packed = (quantize_e2m1(hi) << 4) | quantize_e2m1(lo) - print(" lo=", lo, " hi=", hi, " packed=", packed) - print(" unpack lo=", unpack_e2m1(packed, 0), " unpack hi=", unpack_e2m1(packed, 1)) - - # Test NVFP4 dequantization - print("\nNVFP4 dequantization test:") - packed_w = UInt8(0x36) # low=6.0, high=3.0 - scale = 0.5 - print(" packed=0x36, scale=0.5, lo=", dequantize_nvfp4_weight(packed_w, scale, 0)) - print(" packed=0x36, scale=0.5, hi=", dequantize_nvfp4_weight(packed_w, scale, 1)) diff --git a/src/staging.py b/src/staging.py deleted file mode 100644 index 42a9ec4d..00000000 --- a/src/staging.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -NVFP4 Activation Quantization — BF16 → packed FP4 + UE4M3 block16 scales - -Port of patches/staging_kernel.py to TileLang. - -This quantizes BF16 hidden states to: -- Packed FP4 (E2M1, 2 values per uint8 byte) -- UE4M3 block16 scales (4 values per uint32 word, group_size=16) -""" - -import torch -import tilelang -import tilelang.language as T - - -@tilelang.jit -def fp4_quant_kernel( - X, # (M, K) bfloat16 input - block_M=128, - block_K=128, - group_size=16, # NVFP4 UE4M3 group_size -): - """Quantize BF16 → packed FP4 (E2M1) + UE4M3 block16 scales. - - Output: - - X_FP4: (M, K//2) uint8 packed E2M1 - - X_SF: (M, K//group_size//4) uint32 packed UE4M3 scales - """ - M, K = T.const("M, K") - X: T.Tensor[[M, K], "bfloat16"] - - X_FP4 = T.empty((M, K // 2), "uint8") - X_SF = T.empty((M, K // (group_size * 4)), "uint32") - - with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(K, block_K), threads=128) as (bx, by): - # Each block quantizes a (block_M, block_K) tile - tx = T.get_thread_binding() - - row = bx * block_M + tx // (block_K // 2) - col = tx % (block_K // 2) - - if row < M and col < K // 2: - # Load 2 BF16 values - val_lo = X[row, col * 2] - val_hi = X[row, col * 2 + 1] - - # Quantize to E2M1 (4-bit each, pack 2 per byte) - # ... (E2M1 quantization logic) - packed = T.uint8(0) # placeholder - X_FP4[row, col] = packed - - # Compute block scale for group of 16 elements - # ... (UE4M3 scale computation) - - return X_FP4, X_SF - - -def fp4_quant_reference(x_bf16, group_size=16): - """Reference FP4 quantization using PyTorch + DeepGEMM. - - Used for correctness verification of the TileLang kernel. - """ - from deep_gemm import per_token_cast_to_fp4 - x_fp4, x_sf = per_token_cast_to_fp4(x_bf16) - return x_fp4, x_sf diff --git a/src/weight_transform.py b/src/weight_transform.py deleted file mode 100644 index 528dbdb5..00000000 --- a/src/weight_transform.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -NVFP4 Weight Transformation for TileLang mega_moe kernel. - -Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale) -into the TMA-aligned format expected by the block-scaled GEMM kernel: -- Packed FP4 weights (uint8, K-major) -- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout) - -This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe. -""" - -import torch -import math - - -def fold_global_scale( - weight_scale: torch.Tensor, # (N, K//16) float8_e4m3fn - weight_scale_2: torch.Tensor, # (num_logical,) or scalar float32 - logical_widths: list[int] = None, # per-logical-weight row counts -) -> torch.Tensor: - """Fold global scale into block scales: UE4M3 * FP32 → UE4M3 → float32. - - Returns: (N, K//16) float32 folded block scales. - """ - sf_f32 = weight_scale.to(torch.float32) - - if weight_scale_2.numel() == 1: - sf_f32 = sf_f32 * weight_scale_2.to(torch.float32) - elif weight_scale_2.numel() > 1 and logical_widths is not None: - # Per-logical-weight global scale (e.g., gate_up_proj has 2) - expanded = [] - for i, w in enumerate(logical_widths): - if i < len(weight_scale_2): - expanded.append(weight_scale_2[i].flatten()[0].expand(w)) - global_scale = torch.cat(expanded).to(torch.float32).unsqueeze(1) - sf_f32 = sf_f32 * global_scale - else: - sf_f32 = sf_f32 * weight_scale_2.max().to(torch.float32) - - return sf_f32 - - -def pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor: - """Pack 4 UE4M3 (float8_e4m3fn) values into one uint32. - - Input: (..., K//16) float8_e4m3fn - Output: (..., K//64) uint32 (4 values packed per word) - """ - # View as raw bytes - sf_u8 = sf.view(torch.uint8) # (..., K//16) uint8 - shape = sf_u8.shape - assert shape[-1] % 4 == 0, f"Last dim {shape[-1]} not divisible by 4" - - # Pack 4 consecutive uint8 values into one uint32 - packed = (sf_u8[..., 0::4].to(torch.int32) | - (sf_u8[..., 1::4].to(torch.int32) << 8) | - (sf_u8[..., 2::4].to(torch.int32) << 16) | - (sf_u8[..., 3::4].to(torch.int32) << 24)) - - return packed.contiguous() - - -def interleave_l1_weights( - weight: torch.Tensor, # (E, 2*INTER, K//2) int8, K-major -) -> torch.Tensor: - """Interleave L1 (gate_up) weights for 2CTA UMMA. - - The gate and up projections are interleaved in groups of 8 rows - to match the UMMA 2CTA schedule. - """ - E, N, K_half = weight.shape - assert N % 16 == 0, f"N={N} not divisible by 16" - - # Reshape to (E, N//16, 16, K_half) → interleave pairs of 8 - w = weight.view(E, N // 16, 16, K_half) - # Interleave: [gate_0..7, up_0..7, gate_8..15, up_8..15, ...] - # Actually: pairs of 8 rows from gate and up - w_gate = w[:, :, :8, :] - w_up = w[:, :, 8:16, :] - interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half) - - return interleaved.contiguous() - - -def transform_nvfp4_weights_for_tilelang( - weight: torch.Tensor, # (E, N, K//2) int8, K-major packed FP4 - weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn UE4M3 - weight_scale_2: torch.Tensor, # (E, 1) or (E, 2) float32 global scale - N: int, # output dimension (2*INTER for L1, HIDDEN for L2) - K: int, # input dimension (HIDDEN for L1, INTER for L2) - gran_k: int = 16, # NVFP4 group_size - is_l1: bool = False, # True for gate_up_proj (needs interleaving) - logical_widths: list[int] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """Transform NVFP4 weights for the TileLang block-scaled GEMM. - - Returns: - transformed_weight: (E, N, K//2) int8, K-major, contiguously interleaved - transformed_sf: (E, N, K//64) uint32, packed UE4M3 with folded global scale - """ - device = weight.device - - # Step 1: Fold global scale into block scales - sf_folded = fold_global_scale(weight_scale, weight_scale_2, logical_widths) - - # Step 2: Clamp and convert back to UE4M3 - sf_clamped = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - - # Step 3: Pack UE4M3 into uint32 - sf_packed = pack_ue4m3_to_uint32(sf_clamped) # (E, N, K//64) - - # Step 4: Interleave L1 weights if needed - if is_l1: - transformed_weight = interleave_l1_weights(weight) - else: - transformed_weight = weight.contiguous() - - # Step 5: Ensure K-major layout (already K-major from checkpoint) - # The weight should be (E, N, K//2) with stride (N*K//2, K//2, 1) for K-major - - return transformed_weight, sf_packed - - -def main(): - """Test the weight transformation with dummy data.""" - E = 4 # small test - N = 64 - K = 128 - - device = "cpu" # No GPU on sandbox - weight = torch.randint(-128, 127, (E, N, K // 2), dtype=torch.int8, device=device) - weight_scale = torch.randn(E, N, K // 16, device=device).abs().to(torch.float8_e4m3fn) - weight_scale_2 = torch.ones(E, 2, dtype=torch.float32, device=device) - - transformed_w, transformed_sf = transform_nvfp4_weights_for_tilelang( - weight, weight_scale, weight_scale_2, N, K, is_l1=True, - logical_widths=[32, 32], - ) - - print(f"Weight: {transformed_w.shape} {transformed_w.dtype}") - print(f"SF: {transformed_sf.shape} {transformed_sf.dtype}") - print(f"Expected SF: (E={E}, N={N}, K//64={K//64})") - print("Weight transformation test passed!") - - -if __name__ == "__main__": - main()