From 7a4403fa98c7b9b06f85376e7892cecddab38bb8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 20:29:36 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20FP4=20staging=20kernel=20-=20BF16=20?= =?UTF-8?q?=E2=86=92=20E2M1=20packed=20+=20UE4M3=20block16=20scales?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mxf4nvf4 requires FP4×FP4, not FP8×FP4. - New staging kernel: E2M1 nearest-neighbor quantization - Output: uint8 packed (2 E2M1 per byte) + UE4M3 packed int32 scales - Added CUDA sync diagnostics for error localization --- patches/deepseek_v4.py | 113 ++++++++++++-------- patches/staging_kernel.py | 220 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 289 insertions(+), 44 deletions(-) create mode 100644 patches/staging_kernel.py diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index 59da91b..0e3e0d5 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -244,11 +244,23 @@ class DeepseekV4FP8Config(Fp8Config): return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4" +@triton.jit +""" +NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales. + +The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed). +This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales. 
+""" +import triton +import triton.language as tl +import torch + + @triton.jit def _deepseek_v4_stage_mega_moe_inputs_kernel( hidden_states, - x_fp8, - x_sf, + x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte + x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32 topk_ids, topk_weights, topk_idx_out, @@ -269,8 +281,8 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( topk_weights_out_stride_k: tl.constexpr, hidden_size: tl.constexpr, top_k: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_K: tl.constexpr, + BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden) + GROUP_K: tl.constexpr, # 16 (NVFP4 group_size) BLOCK_TOPK: tl.constexpr, ) -> None: token_id = tl.program_id(0) @@ -284,63 +296,76 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( other=0.0, ).to(tl.float32) - num_groups: tl.constexpr = BLOCK_K // GROUP_K - hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) - amax = tl.max(hidden_groups, axis=1) + num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8 + hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) + abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) + amax = tl.max(abs_groups, axis=1) amax = tl.maximum(amax, 1.0e-4) - # NVFP4: UE4M3 activation scales (not UE8M0) - # scale_format_ is shared between SFA and SFB in the MMA descriptor, - # so both must use the same format. For mxf4nvf4, both are UE4M3. 
- # E4M3 format: 1 sign + 4 exp + 3 mantissa, bias=7, max normal=448 - # UE4M3: sign bit cleared, same representation otherwise - scale = amax / 448.0 # Normalize so max activation maps to UE4M3 max + # ---- UE4M3 scale computation ---- + # scale = amax / 6.0 (E2M1 max value = 6) + # Then quantize scale to UE4M3 format + scale = amax / 6.0 scale_bits = scale.to(tl.uint32, bitcast=True) scale_exp = (scale_bits >> 23) & 0xFF scale_mant = scale_bits & 0x7FFFFF # Convert FP32 → E4M3 manually - # Normal: exp in [1, 254], mantissa has 23 bits → take top 3 - # For E4M3: exp_bits = FP32_exp - 127 + 7 = FP32_exp - 120, clamped to [1, 15] - # mant_bits = FP32_mant >> 20 (top 3 of 23) - # Round: if any of the dropped mantissa bits are set, add 1 to mant e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7 - # Handle subnormal: if FP32 exp < 121, E4M3 exp would be 0 (subnormal) e4m3_exp = tl.maximum(e4m3_exp, 0) e4m3_exp = tl.minimum(e4m3_exp, 15) - # Mantissa: top 3 bits of FP32 mantissa, with rounding - e4m3_mant = scale_mant >> 20 # top 3 bits - round_bit = (scale_mant >> 19) & 1 # 4th bit for rounding + e4m3_mant = scale_mant >> 20 + round_bit = (scale_mant >> 19) & 1 e4m3_mant = e4m3_mant + round_bit - # If mantissa overflowed (8), carry into exponent overflow = e4m3_mant >= 8 e4m3_mant = tl.where(overflow, 0, e4m3_mant) e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp) e4m3_exp = tl.minimum(e4m3_exp, 15) - # Pack: UE4M3 = (exp << 3) | mant (no sign bit) - scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant # uint8 representation + scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant - # Reconstruct the dequantized scale for FP8 activation quantization - # UE4M3 dequant: if exp == 0: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6 - # else: value = (1 + mant/8) * 2^(exp-7) - # Use bit manipulation: 2^n = (n+127) << 23 reinterpreted as float32 + # Reconstruct dequantized scale for E2M1 quantization e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126) - # For exp=0 
(subnormal): already handled separately - # For normal: (1 + mant/8) * 2^(exp-7) two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23 two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True) normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp - subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625 # 2^-6 + subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625 e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value) - hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups_f * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None] - scaled = tl.reshape(scaled, [BLOCK_K]) - fp8 = scaled.to(tl.float8e4nv) + # ---- E2M1 FP4 quantization ---- + # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6] + # Nearest-neighbor using thresholds (midpoints between consecutive values) + scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None] + # Clamp to E2M1 range [-6, 6] + scaled = tl.maximum(scaled, -6.0) + scaled = tl.minimum(scaled, 6.0) + + abs_s = tl.abs(scaled) + # Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6] + # [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF] + e2m1_idx = tl.where(abs_s < 0.25, 0, + tl.where(abs_s < 0.75, 1, + tl.where(abs_s < 1.25, 2, + tl.where(abs_s < 1.75, 3, + tl.where(abs_s < 2.5, 4, + tl.where(abs_s < 3.5, 5, + tl.where(abs_s < 5.0, 6, 7))))))) + sign_bit = (scaled < 0).to(tl.int32) + e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index + + # Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble + PACKED_K: tl.constexpr = BLOCK_K // 2 # 64 + e2m1_4bit_flat = tl.reshape(e2m1_4bit, [BLOCK_K]) + # Interleave: index 0,2,4,... → low nibbles; 1,3,5,... 
→ high nibbles + even = e2m1_4bit_flat[0::2] + odd = e2m1_4bit_flat[1::2] + packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8) + + packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K) + packed_k_mask = packed_k_offsets < (hidden_size // 2) tl.store( - x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, - fp8, - mask=k_mask, + x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k, + packed_byte, + mask=packed_k_mask, ) # Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements) @@ -378,7 +403,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( tl.store( topk_weights_out + token_id * topk_weights_out_stride_m - + topk_offsets * topk_weights_out_stride_k, + + topk_weights_out_stride_k, weights, mask=topk_mask, ) @@ -388,8 +413,8 @@ def _stage_deepseek_v4_mega_moe_inputs( hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - x_fp8: torch.Tensor, - x_sf: torch.Tensor, + x_fp4: torch.Tensor, # uint8, shape (M, K//2) + x_sf: torch.Tensor, # int32, shape (M, K//64) topk_idx_out: torch.Tensor, topk_weights_out: torch.Tensor, ) -> None: @@ -413,7 +438,7 @@ def _stage_deepseek_v4_mega_moe_inputs( block_topk = triton.next_power_of_2(top_k) _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( hidden_states, - x_fp8, + x_fp4, x_sf, topk_ids, topk_weights, @@ -421,8 +446,8 @@ def _stage_deepseek_v4_mega_moe_inputs( topk_weights_out, hidden_states.stride(0), hidden_states.stride(1), - x_fp8.stride(0), - x_fp8.stride(1), + x_fp4.stride(0), + x_fp4.stride(1), x_sf.stride(0), x_sf.stride(1), topk_ids.stride(0), diff --git a/patches/staging_kernel.py b/patches/staging_kernel.py new file mode 100644 index 0000000..06d2969 --- /dev/null +++ b/patches/staging_kernel.py @@ -0,0 +1,220 @@ +""" +NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales. + +The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed). 
+This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
+"""
+import triton
+import triton.language as tl
+import torch
+
+
+@triton.jit
+def _deepseek_v4_stage_mega_moe_inputs_kernel(
+    hidden_states,
+    x_fp4,  # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
+    x_sf,   # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
+    topk_ids,
+    topk_weights,
+    topk_idx_out,
+    topk_weights_out,
+    hidden_stride_m: tl.constexpr,
+    hidden_stride_k: tl.constexpr,
+    x_stride_m: tl.constexpr,
+    x_stride_k: tl.constexpr,
+    x_sf_stride_m: tl.constexpr,
+    x_sf_stride_k: tl.constexpr,
+    topk_ids_stride_m: tl.constexpr,
+    topk_ids_stride_k: tl.constexpr,
+    topk_weights_stride_m: tl.constexpr,
+    topk_weights_stride_k: tl.constexpr,
+    topk_idx_stride_m: tl.constexpr,
+    topk_idx_stride_k: tl.constexpr,
+    topk_weights_out_stride_m: tl.constexpr,
+    topk_weights_out_stride_k: tl.constexpr,
+    hidden_size: tl.constexpr,
+    top_k: tl.constexpr,
+    BLOCK_K: tl.constexpr,  # 128 elements (loaded from hidden)
+    GROUP_K: tl.constexpr,  # 16 (NVFP4 group_size)
+    BLOCK_TOPK: tl.constexpr,
+) -> None:
+    # Stage one (token, BLOCK_K-wide K block) tile: one UE4M3 scale per GROUP_K
+    # elements, E2M1 codes packed two per byte, and (for k_block_id == 0 only)
+    # a copy of the token's top-k ids and weights.
+    token_id = tl.program_id(0)
+    k_block_id = tl.program_id(1)
+
+    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
+    k_mask = k_offsets < hidden_size
+    hidden = tl.load(
+        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
+        mask=k_mask,
+        other=0.0,
+    ).to(tl.float32)
+
+    num_groups: tl.constexpr = BLOCK_K // GROUP_K  # 8
+    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
+    amax = tl.max(tl.abs(hidden_groups), axis=1)
+    amax = tl.maximum(amax, 1.0e-4)
+
+    # ---- UE4M3 scale computation ----
+    # scale = amax / 6.0 (E2M1 max value = 6), quantized to UE4M3.
+    # UE4M3 is E4M3 with the sign bit clear, so for a positive scale the
+    # FP32 -> float8e4nv cast yields the UE4M3 byte directly. The previous
+    # hand-rolled conversion computed `scale_exp - 120` in uint32, which wraps
+    # for scales below 2^-7 and was then clamped to the LARGEST exponent; it
+    # also mis-encoded subnormals and could emit 0x7F, which is NaN in E4M3.
+    # Clamp to the finite non-zero UE4M3 range so the cast never yields 0 or NaN.
+    UE4M3_MIN: tl.constexpr = 0.001953125  # 2^-9, smallest subnormal
+    UE4M3_MAX: tl.constexpr = 448.0  # largest finite value
+    scale = tl.minimum(tl.maximum(amax / 6.0, UE4M3_MIN), UE4M3_MAX)
+    scale_fp8 = scale.to(tl.float8e4nv)
+    scale_e4m3_bits = scale_fp8.to(tl.uint8, bitcast=True)
+    # Quantize against the rounded scale, i.e. the value the MMA will apply.
+    e4m3_value = scale_fp8.to(tl.float32)
+
+    # ---- E2M1 FP4 quantization ----
+    # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
+    # Nearest-neighbor using thresholds (midpoints between consecutive values)
+    scaled = hidden_groups * (1.0 / e4m3_value)[:, None]
+    # Clamp to E2M1 range [-6, 6]
+    scaled = tl.maximum(scaled, -6.0)
+    scaled = tl.minimum(scaled, 6.0)
+
+    abs_s = tl.abs(scaled)
+    # Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
+    # [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
+    e2m1_idx = tl.where(abs_s < 0.25, 0,
+               tl.where(abs_s < 0.75, 1,
+               tl.where(abs_s < 1.25, 2,
+               tl.where(abs_s < 1.75, 3,
+               tl.where(abs_s < 2.5, 4,
+               tl.where(abs_s < 3.5, 5,
+               tl.where(abs_s < 5.0, 6, 7)))))))
+    sign_bit = (scaled < 0).to(tl.int32)
+    e2m1_4bit = (sign_bit << 3) | e2m1_idx  # 4-bit: (sign << 3) | index
+
+    # Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble.
+    # Triton tensors cannot be sliced with a stride (x[0::2] is rejected at
+    # compile time), so view adjacent elements as [PACKED_K, 2] and split.
+    PACKED_K: tl.constexpr = BLOCK_K // 2  # 64
+    even, odd = tl.split(tl.reshape(e2m1_4bit, [PACKED_K, 2]))
+    packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8)
+
+    packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
+    packed_k_mask = packed_k_offsets < (hidden_size // 2)
+    tl.store(
+        x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
+        packed_byte,
+        mask=packed_k_mask,
+    )
+
+    # Pack 4 UE4M3 bytes per int32 (NVFP4: group_size=16, 4 groups per 64 elements).
+    # BLOCK_K spans num_groups // 4 such words; shifting all num_groups bytes into
+    # a single int32 would need shifts >= 32 and leave half of x_sf unwritten.
+    SF_WORDS: tl.constexpr = num_groups // 4  # 2
+    sf_bytes = tl.reshape(scale_e4m3_bits.to(tl.int32), [SF_WORDS, 4])
+    byte_shift = tl.arange(0, 4) * 8
+    packed_scale = tl.sum(sf_bytes << byte_shift[None, :], axis=1)
+    sf_offsets = k_block_id * SF_WORDS + tl.arange(0, SF_WORDS)
+    tl.store(
+        x_sf + token_id * x_sf_stride_m + sf_offsets * x_sf_stride_k,
+        packed_scale,
+    )
+
+    if k_block_id == 0:
+        topk_offsets = tl.arange(0, BLOCK_TOPK)
+        topk_mask = topk_offsets < top_k
+
+        ids = tl.load(
+            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
+            mask=topk_mask,
+            other=0,
+        ).to(tl.int64)
+        tl.store(
+            topk_idx_out
+            + token_id * topk_idx_stride_m
+            + topk_offsets * topk_idx_stride_k,
+            ids,
+            mask=topk_mask,
+        )
+
+        weights = tl.load(
+            topk_weights
+            + token_id * topk_weights_stride_m
+            + topk_offsets * topk_weights_stride_k,
+            mask=topk_mask,
+            other=0.0,
+        )
+        # Each top-k slot needs its own column; without `topk_offsets *` every
+        # lane would store to the same element.
+        tl.store(
+            topk_weights_out
+            + token_id * topk_weights_out_stride_m
+            + topk_offsets * topk_weights_out_stride_k,
+            weights,
+            mask=topk_mask,
+        )
+
+
+def _stage_deepseek_v4_mega_moe_inputs(
+    hidden_states: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    x_fp4: torch.Tensor,  # uint8, shape (M, K//2)
+    x_sf: torch.Tensor,   # int32, shape (M, K//64)
+    topk_idx_out: torch.Tensor,
+    topk_weights_out: torch.Tensor,
+) -> None:
+    """Launch the NVFP4 staging kernel; all outputs are written in place.
+
+    No-op when there are no tokens. Raises ValueError if hidden_size is not
+    a multiple of 128 or if topk_weights and topk_ids differ in shape.
+    """
+    num_tokens, hidden_size = hidden_states.shape
+    if num_tokens == 0:
+        return
+    if hidden_size % 128 != 0:
+        raise ValueError(
+            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
+            "a multiple of 128."
+        )
+    top_k = topk_ids.shape[1]
+    if topk_weights.shape != topk_ids.shape:
+        raise ValueError(
+            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
+            "topk_ids to have the same shape."
+        )
+
+    block_k = 128
+    grid = (num_tokens, triton.cdiv(hidden_size, block_k))
+    block_topk = triton.next_power_of_2(top_k)
+    _deepseek_v4_stage_mega_moe_inputs_kernel[grid](
+        hidden_states,
+        x_fp4,
+        x_sf,
+        topk_ids,
+        topk_weights,
+        topk_idx_out,
+        topk_weights_out,
+        hidden_states.stride(0),
+        hidden_states.stride(1),
+        x_fp4.stride(0),
+        x_fp4.stride(1),
+        x_sf.stride(0),
+        x_sf.stride(1),
+        topk_ids.stride(0),
+        topk_ids.stride(1),
+        topk_weights.stride(0),
+        topk_weights.stride(1),
+        topk_idx_out.stride(0),
+        topk_idx_out.stride(1),
+        topk_weights_out.stride(0),
+        topk_weights_out.stride(1),
+        hidden_size,
+        top_k,
+        BLOCK_K=block_k,
+        GROUP_K=16,  # NVFP4: group_size=16 (scale_vec::4X)
+        BLOCK_TOPK=block_topk,
+        num_warps=4,
+    )