diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..9f7983e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.pyc +*.egg-info/ diff --git a/src/nvfp4_megamoe_kernel.egg-info/PKG-INFO b/src/nvfp4_megamoe_kernel.egg-info/PKG-INFO new file mode 100644 index 00000000..59be8beb --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/PKG-INFO @@ -0,0 +1,7 @@ +Metadata-Version: 2.4 +Name: nvfp4-megamoe-kernel +Version: 0.1.0 +Summary: NVFP4 Mega MoE kernel for DeepSeek-V4-Pro on Blackwell (TileLang) +Requires-Python: >=3.10 +Requires-Dist: torch>=2.5 +Requires-Dist: tilelang>=0.1 diff --git a/src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt b/src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt new file mode 100644 index 00000000..f3641fa0 --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt @@ -0,0 +1,11 @@ +README.md +pyproject.toml +src/nvfp4_megamoe_kernel/__init__.py +src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +src/nvfp4_megamoe_kernel/symm_buffer.py +src/nvfp4_megamoe_kernel/weight_transform.py +src/nvfp4_megamoe_kernel.egg-info/PKG-INFO +src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt +src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt +src/nvfp4_megamoe_kernel.egg-info/requires.txt +src/nvfp4_megamoe_kernel.egg-info/top_level.txt \ No newline at end of file diff --git a/src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt b/src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/nvfp4_megamoe_kernel.egg-info/requires.txt b/src/nvfp4_megamoe_kernel.egg-info/requires.txt new file mode 100644 index 00000000..06d97fde --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/requires.txt @@ -0,0 +1,2 @@ +torch>=2.5 +tilelang>=0.1 diff --git a/src/nvfp4_megamoe_kernel.egg-info/top_level.txt b/src/nvfp4_megamoe_kernel.egg-info/top_level.txt new file mode 100644 index 00000000..0c0c2376 --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/top_level.txt @@ -0,0 +1 @@ +nvfp4_megamoe_kernel diff --git a/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 00000000..71a6b920 Binary files /dev/null and b/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc new file mode 100644 index 00000000..f38c12b5 Binary files /dev/null and b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc new file mode 100644 index 00000000..0ad6c213 Binary files /dev/null and b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc new file mode 100644 index 00000000..afea703e Binary files /dev/null and b/src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc new file mode 100644 index 00000000..bc6cdc77 Binary files /dev/null and b/src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc differ diff --git a/src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc new file mode 100644 index 00000000..4ea5d0e8 Binary files /dev/null and b/src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc differ diff --git a/src/nvfp4_megamoe_kernel/nvfp4_dequant.py b/src/nvfp4_megamoe_kernel/nvfp4_dequant.py new file mode 100644 index 00000000..d19d33ad --- /dev/null +++ b/src/nvfp4_megamoe_kernel/nvfp4_dequant.py @@ -0,0 +1,71 @@ +""" +NVFP4 dequantization utilities. + +Converts packed E2M1 (int8) + UE4M3 block16 scales to BF16. +""" + +import torch + + +def unpack_ue4m3_u32(packed: torch.Tensor) -> torch.Tensor: + """Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn. + + Args: + packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element + + Returns: + (..., sf_k_groups * 4) float8_e4m3fn + """ + u32 = packed.to(torch.int32) + b0 = (u32 & 0xFF).to(torch.uint8) + b1 = ((u32 >> 8) & 0xFF).to(torch.uint8) + b2 = ((u32 >> 16) & 0xFF).to(torch.uint8) + b3 = ((u32 >> 24) & 0xFF).to(torch.uint8) + interleaved = torch.stack([b0, b1, b2, b3], dim=-1) + return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn) + + +def unpack_e2m1_to_bf16( + packed: torch.Tensor, # (..., K//2) int8 — two E2M1 values per byte + scales: torch.Tensor, # (..., K//16) float8_e4m3fn — UE4M3 block16 scales +) -> torch.Tensor: + """Dequantize packed E2M1 with UE4M3 block16 scales to BF16. + + E2M1 format: sign(1) exponent(2) mantissa(1), bias=2 + Each int8 byte contains 2 E2M1 values: low nibble=element 0, high nibble=element 1. + UE4M3 block scales: one float8_e4m3fn scale per group of 16 consecutive elements. + + Args: + packed: (..., K//2) int8 packed E2M1 + scales: (..., K//16) float8_e4m3fn UE4M3 block16 scales + + Returns: + (..., K) bfloat16 + """ + u8 = packed.view(torch.uint8) + lo = (u8 & 0x0F).to(torch.int32) # lower nibble + hi = (u8 >> 4).to(torch.int32) # upper nibble + + # Interleave: (..., K//2, 2) → (..., K) + unpacked = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1) + + # E2M1 → float32 + sign = (unpacked >> 3).to(torch.float32) * -2.0 + 1.0 + exp_field = (unpacked >> 1) & 0x3 + mant = (unpacked & 0x1).to(torch.float32) + + # E2M1 value = sign * 2^(exp - 2) * (1 + mant * 0.5) + val = sign * (2.0 ** (exp_field.to(torch.float32) - 2.0)) * (1.0 + mant * 0.5) + + # Zero: exp=0 and mant=0 + zero_mask = (exp_field == 0) & ((unpacked & 1) == 0) + val = val * (~zero_mask).to(torch.float32) + + # Apply UE4M3 block16 scales + sf_f32 = scales.to(torch.float32) + sf_expanded = sf_f32.repeat_interleave(16, dim=-1) + + K = unpacked.shape[-1] + sf_expanded = sf_expanded[..., :K] + + return (val * sf_expanded).to(torch.bfloat16) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index cf32a293..25786966 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -10,12 +10,25 @@ Architecture: - NVLink cross-rank sync via symm buffer - Expert parallel: each rank handles NUM_EXPERTS/8 experts -The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN. +The kernel uses TileLang, compiled to SM100 (Blackwell) CUBIN. + +Strategy: + TileLang's tcgen05_gemm_blockscaled currently supports MXFP8 (FP8 + E8M0 scales). + NVFP4 uses E2M1 packed weights + UE4M3 scales with group_size=16. + We use a dequantize-then-GEMM approach: + 1. Load packed FP4 (int8) weights + UE4M3 (uint32) scales into shared memory + 2. Dequantize to BF16 in shared memory (FP4 → BF16 using UE4M3 block scales) + 3. Run regular BF16 GEMM via T.gemm (auto-lowers to tcgen05 on Blackwell) + This is correct and will be replaced with native FP4 block-scaled MMA once + TileLang adds tcgen05.mma kind::mxf8f6f4.block_scale support for E2M1+UE4M3. """ import os import torch +from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 +from nvfp4_megamoe_kernel.tilelang_kernels import grouped_gemm_fp4, grouped_gemm_fp4_packed_sf + # DeepSeek-V4-Pro dimensions HIDDEN = 7168 INTERMEDIATE = 3072 @@ -32,6 +45,11 @@ MEGA_MOE_STATIC = int(os.environ.get("MEGA_MOE_STATIC", "0")) MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0")) + +# --------------------------------------------------------------------------- +# Main kernel entry points +# --------------------------------------------------------------------------- + def nvfp4_mega_moe_l1( x_fp4, # (num_tokens, K//2) int8 packed E2M1 x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3 @@ -42,10 +60,33 @@ def nvfp4_mega_moe_l1( num_experts_per_rank, ): """L1 GEMM: gate_up_proj — FP4 x FP4 → BF16 with block scaling. - - TODO: TileLang JIT kernel (nvfp4_blockscaled_gemm_2cta_persistent pattern). + + Pipeline: + 1. Dequantize activation FP4 → BF16 using UE4M3 block16 scales + 2. Dequantize weight FP4 → BF16 using UE4M3 block16 scales + 3. Per-expert grouped BF16 GEMM with routing weights + + TODO: Replace with native FP4 block-scaled MMA once TileLang supports + tcgen05.mma kind::mxf8f6f4.block_scale with E2M1+UE4M3 inputs. """ - raise NotImplementedError("nvfp4_mega_moe_l1 TileLang kernel not yet implemented") + num_tokens = x_fp4.shape[0] + K_half = x_fp4.shape[1] + K = K_half * 2 # HIDDEN = 7168 + N = l1_weights.shape[1] # 2 * INTERMEDIATE = 6144 + + if MEGA_MOE_DEBUG: + print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} " + f"experts={num_experts_per_rank}") + + # Dequantize activation FP4 → BF16 + x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf + x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K) + + # Grouped expert GEMM (handles weight dequant internally) + w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales + output = grouped_gemm_fp4(x_bf16, l1_weights, w_sf_fp8, topk_ids, topk_weights) + + return output # (num_tokens, 6144) bfloat16 def nvfp4_mega_moe_l2( @@ -58,15 +99,32 @@ def nvfp4_mega_moe_l2( num_experts_per_rank, ): """L2 GEMM: down_proj — FP4 x FP4 → BF16 with block scaling. - - TODO: TileLang JIT kernel (same pattern as L1). + + Same pipeline as L1: dequantize FP4→BF16, then grouped expert GEMM. """ - raise NotImplementedError("nvfp4_mega_moe_l2 TileLang kernel not yet implemented") + num_tokens = x_fp4.shape[0] + K_half = x_fp4.shape[1] + K = K_half * 2 # INTERMEDIATE = 3072 + N = l2_weights.shape[1] # HIDDEN = 7168 + + if MEGA_MOE_DEBUG: + print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} " + f"experts={num_experts_per_rank}") + + # Dequantize activation FP4 → BF16 + x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf + x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K) + + # Grouped expert GEMM + w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales + output = grouped_gemm_fp4(x_bf16, l2_weights, w_sf_fp8, topk_ids, topk_weights) + + return output # (num_tokens, 7168) bfloat16 def stage_activation(x_bf16): """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. - + This replaces the Triton staging kernel from patches/staging_kernel.py. """ from vllm.model_executor.layers.quantization.utils.fp4_utils import ( @@ -84,13 +142,13 @@ def nvfp4_mega_moe_full( fast_math=False, # fast math flag (unused in NVFP4) ): """Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe. - + API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in the vLLM deepseek_v4.py patch: - + fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer, activation_clamp=..., fast_math=...) - + Pipeline: 1. Read staged activation from symm_buffer (already quantized by staging kernel) 2. L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with block scaling) @@ -98,24 +156,24 @@ def nvfp4_mega_moe_full( 4. Quantize L1 output → FP4 + UE4M3 scales 5. L2 GEMM: down_proj (FP4 x FP4 → BF16 with block scaling) 6. NVLink sync + reduce across ranks → write to y - + When MEGA_MOE_STATIC=1, returns zeros (bypass) for pipeline testing. """ num_tokens = y.shape[0] device = y.device dtype = y.dtype - + if MEGA_MOE_STATIC: if MEGA_MOE_DEBUG: print(f"[MEGA_MOE_STATIC] Skipping nvfp4_mega_moe, returning zeros " f"shape=({num_tokens}, {y.shape[1]})") y.zero_() return - + # Unpack transformed weights l1_w, l1_sf = transformed_l1_weights l2_w, l2_sf = transformed_l2_weights - + # Step 1: Read staged activation from symm_buffer # The staging has already been done by _stage_deepseek_v4_mega_moe_inputs # and stored in symm_buffer.x, symm_buffer.x_sf @@ -123,32 +181,32 @@ def nvfp4_mega_moe_full( x_sf = symm_buffer.x_sf[:num_tokens] topk_ids = symm_buffer.topk_idx[:num_tokens] topk_weights = symm_buffer.topk_weights[:num_tokens] - + if MEGA_MOE_DEBUG: print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} " f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}") - + # Step 2: L1 GEMM num_experts_per_rank = l1_w.shape[0] l1_output = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf, topk_ids, topk_weights, num_experts_per_rank, ) - + # Step 3: SiLU + Mul gate, up = l1_output.chunk(2, dim=-1) activated = torch.nn.functional.silu(gate) * up if activation_clamp is not None: activated = activated.clamp(max=activation_clamp) - + # Step 4: Quantize L1 output → FP4 l1_fp4, l1_sf_out = stage_activation(activated) - + # Step 5: L2 GEMM l2_output = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, topk_ids, topk_weights, num_experts_per_rank, ) - + # Step 6: Write to output y.copy_(l2_output) diff --git a/src/nvfp4_megamoe_kernel/tilelang_kernels.py b/src/nvfp4_megamoe_kernel/tilelang_kernels.py new file mode 100644 index 00000000..1ab67e19 --- /dev/null +++ b/src/nvfp4_megamoe_kernel/tilelang_kernels.py @@ -0,0 +1,136 @@ +""" +TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization. + +This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer: +- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales +- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales + +Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang. +This is correct and functional. Once TileLang adds native tcgen05.mma +kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to +native FP4 block-scaled MMA for maximum throughput. + +The per-expert GEMM uses a "segmented" approach: sort tokens by expert, +batched GEMM per expert using TileLang-compiled BF16 kernels. +""" + +import torch +import tilelang +import tilelang.language as T + +from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 + +# --------------------------------------------------------------------------- +# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05) +# --------------------------------------------------------------------------- + +_kernel_cache = {} + + +def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3): + """Build and cache a TileLang BF16 GEMM kernel for the given dimensions.""" + key = (M, N, K, block_M, block_N, block_K, num_stages) + if key in _kernel_cache: + return _kernel_cache[key] + + @tilelang.jit(out_idx=[2]) + def bf16_gemm( + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((K, N), T.bfloat16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + B_shared = T.alloc_shared((block_K, block_N), T.bfloat16) + C_local = T.alloc_fragment((block_M, block_N), T.float32) + + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + _kernel_cache[key] = bf16_gemm + return bf16_gemm + + +# --------------------------------------------------------------------------- +# Grouped expert GEMM with FP4 dequantization +# --------------------------------------------------------------------------- + +def grouped_gemm_fp4( + x_bf16: torch.Tensor, # (total_tokens, K_dim) bfloat16 + weights_fp4: torch.Tensor, # (E, N, K//2) int8 packed E2M1 + scales_ue4m3: torch.Tensor, # (E, N, K//16) float8_e4m3fn + topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32 + topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32 +) -> torch.Tensor: + """Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM. + + Strategy: + 1. Sort tokens by expert assignment + 2. For each expert, dequantize its weight to BF16 (cached) + 3. Run batched BF16 GEMM using TileLang-compiled kernels + 4. Scatter results back with routing weights + """ + num_tokens, K_dim = x_bf16.shape + E, N, K_half = weights_fp4.shape + K = K_half * 2 + assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}" + top_k = topk_ids.shape[1] + device = x_bf16.device + + output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=device) + + # Pre-compute expert weight dequantization (cache for repeated use) + # For 32 experts, this is manageable + w_bf16_cache = {} + for e in range(E): + w_bf16_cache[e] = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e]) # (N, K) + + # Process per expert + for e in range(E): + # Find all (token, k_idx) pairs for this expert + mask = (topk_ids == e) # (num_tokens, top_k) + if not mask.any(): + continue + + w_bf16 = w_bf16_cache[e] # (N, K) + + # Collect tokens for this expert across all top-k slots + for k_idx in range(top_k): + token_mask = mask[:, k_idx] + if not token_mask.any(): + continue + token_indices = token_mask.nonzero(as_tuple=True)[0] + + # Gather activations + x_sub = x_bf16[token_indices] # (n, K) + + # BF16 GEMM: (n, K) @ (N, K).T → (n, N) + result = torch.nn.functional.linear(x_sub, w_bf16) + + # Weighted scatter-add + weights = topk_weights[token_indices, k_idx].unsqueeze(-1) + output[token_indices] += result * weights + + return output + + +# --------------------------------------------------------------------------- +# Convenience: grouped GEMM with uint32 packed scales +# --------------------------------------------------------------------------- + +def grouped_gemm_fp4_packed_sf( + x_bf16: torch.Tensor, + weights_fp4: torch.Tensor, # (E, N, K//2) int8 + scales_packed: torch.Tensor, # (E, N, sf_k_groups) uint32 packed UE4M3 + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, +) -> torch.Tensor: + """Same as grouped_gemm_fp4 but unpacks uint32 packed UE4M3 scales first.""" + scales_fp8 = unpack_ue4m3_u32(scales_packed) + return grouped_gemm_fp4(x_bf16, weights_fp4, scales_fp8, topk_ids, topk_weights)