feat: FP4 staging kernel - BF16 → E2M1 packed + UE4M3 block16 scales

mxf4nvf4 requires FP4×FP4, not FP8×FP4.
- New staging kernel: E2M1 nearest-neighbor quantization
- Output: uint8 packed (2 E2M1 per byte) + UE4M3 packed int32 scales
- Added CUDA sync diagnostics for error localization
This commit is contained in:
2026-05-11 20:29:36 +00:00
parent 0fd2d4f078
commit 7a4403fa98
2 changed files with 289 additions and 44 deletions

View File

@@ -244,11 +244,23 @@ class DeepseekV4FP8Config(Fp8Config):
return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4"
@triton.jit
"""
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
"""
import triton
import triton.language as tl
import torch
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
hidden_states,
x_fp8,
x_sf,
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
topk_ids,
topk_weights,
topk_idx_out,
@@ -269,8 +281,8 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
topk_weights_out_stride_k: tl.constexpr,
hidden_size: tl.constexpr,
top_k: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_K: tl.constexpr,
BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden)
GROUP_K: tl.constexpr, # 16 (NVFP4 group_size)
BLOCK_TOPK: tl.constexpr,
) -> None:
token_id = tl.program_id(0)
@@ -284,63 +296,76 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
other=0.0,
).to(tl.float32)
num_groups: tl.constexpr = BLOCK_K // GROUP_K
hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
amax = tl.max(hidden_groups, axis=1)
num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
amax = tl.max(abs_groups, axis=1)
amax = tl.maximum(amax, 1.0e-4)
# NVFP4: UE4M3 activation scales (not UE8M0)
# scale_format_ is shared between SFA and SFB in the MMA descriptor,
# so both must use the same format. For mxf4nvf4, both are UE4M3.
# E4M3 format: 1 sign + 4 exp + 3 mantissa, bias=7, max normal=448
# UE4M3: sign bit cleared, same representation otherwise
scale = amax / 448.0 # Normalize so max activation maps to UE4M3 max
# ---- UE4M3 scale computation ----
# scale = amax / 6.0 (E2M1 max value = 6)
# Then quantize scale to UE4M3 format
scale = amax / 6.0
scale_bits = scale.to(tl.uint32, bitcast=True)
scale_exp = (scale_bits >> 23) & 0xFF
scale_mant = scale_bits & 0x7FFFFF
# Convert FP32 → E4M3 manually
# Normal: exp in [1, 254], mantissa has 23 bits → take top 3
# For E4M3: exp_bits = FP32_exp - 127 + 7 = FP32_exp - 120, clamped to [1, 15]
# mant_bits = FP32_mant >> 20 (top 3 of 23)
# Round: if any of the dropped mantissa bits are set, add 1 to mant
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
# Handle subnormal: if FP32 exp < 121, E4M3 exp would be 0 (subnormal)
e4m3_exp = tl.maximum(e4m3_exp, 0)
e4m3_exp = tl.minimum(e4m3_exp, 15)
# Mantissa: top 3 bits of FP32 mantissa, with rounding
e4m3_mant = scale_mant >> 20 # top 3 bits
round_bit = (scale_mant >> 19) & 1 # 4th bit for rounding
e4m3_mant = scale_mant >> 20
round_bit = (scale_mant >> 19) & 1
e4m3_mant = e4m3_mant + round_bit
# If mantissa overflowed (8), carry into exponent
overflow = e4m3_mant >= 8
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
e4m3_exp = tl.minimum(e4m3_exp, 15)
# Pack: UE4M3 = (exp << 3) | mant (no sign bit)
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant # uint8 representation
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
# Reconstruct the dequantized scale for FP8 activation quantization
# UE4M3 dequant: if exp == 0: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6
# else: value = (1 + mant/8) * 2^(exp-7)
# Use bit manipulation: 2^n = (n+127) << 23 reinterpreted as float32
# Reconstruct dequantized scale for E2M1 quantization
e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126)
# For exp=0 (subnormal): already handled separately
# For normal: (1 + mant/8) * 2^(exp-7)
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625 # 2^-6
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K])
scaled = hidden_groups_f * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
scaled = tl.reshape(scaled, [BLOCK_K])
fp8 = scaled.to(tl.float8e4nv)
# ---- E2M1 FP4 quantization ----
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# Nearest-neighbor using thresholds (midpoints between consecutive values)
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
# Clamp to E2M1 range [-6, 6]
scaled = tl.maximum(scaled, -6.0)
scaled = tl.minimum(scaled, 6.0)
abs_s = tl.abs(scaled)
# Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
e2m1_idx = tl.where(abs_s < 0.25, 0,
tl.where(abs_s < 0.75, 1,
tl.where(abs_s < 1.25, 2,
tl.where(abs_s < 1.75, 3,
tl.where(abs_s < 2.5, 4,
tl.where(abs_s < 3.5, 5,
tl.where(abs_s < 5.0, 6, 7)))))))
sign_bit = (scaled < 0).to(tl.int32)
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
# Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble
PACKED_K: tl.constexpr = BLOCK_K // 2 # 64
e2m1_4bit_flat = tl.reshape(e2m1_4bit, [BLOCK_K])
# Interleave: index 0,2,4,... → low nibbles; 1,3,5,... → high nibbles
even = e2m1_4bit_flat[0::2]
odd = e2m1_4bit_flat[1::2]
packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8)
packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
packed_k_mask = packed_k_offsets < (hidden_size // 2)
tl.store(
x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
fp8,
mask=k_mask,
x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
packed_byte,
mask=packed_k_mask,
)
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
@@ -378,7 +403,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
tl.store(
topk_weights_out
+ token_id * topk_weights_out_stride_m
+ topk_offsets * topk_weights_out_stride_k,
+ topk_weights_out_stride_k,
weights,
mask=topk_mask,
)
@@ -388,8 +413,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
x_fp8: torch.Tensor,
x_sf: torch.Tensor,
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
x_sf: torch.Tensor, # int32, shape (M, K//64)
topk_idx_out: torch.Tensor,
topk_weights_out: torch.Tensor,
) -> None:
@@ -413,7 +438,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
block_topk = triton.next_power_of_2(top_k)
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
hidden_states,
x_fp8,
x_fp4,
x_sf,
topk_ids,
topk_weights,
@@ -421,8 +446,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
topk_weights_out,
hidden_states.stride(0),
hidden_states.stride(1),
x_fp8.stride(0),
x_fp8.stride(1),
x_fp4.stride(0),
x_fp4.stride(1),
x_sf.stride(0),
x_sf.stride(1),
topk_ids.stride(0),

220
patches/staging_kernel.py Normal file
View File

@@ -0,0 +1,220 @@
"""
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
"""
import triton
import triton.language as tl
import torch
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
hidden_states,
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
topk_ids,
topk_weights,
topk_idx_out,
topk_weights_out,
hidden_stride_m: tl.constexpr,
hidden_stride_k: tl.constexpr,
x_stride_m: tl.constexpr,
x_stride_k: tl.constexpr,
x_sf_stride_m: tl.constexpr,
x_sf_stride_k: tl.constexpr,
topk_ids_stride_m: tl.constexpr,
topk_ids_stride_k: tl.constexpr,
topk_weights_stride_m: tl.constexpr,
topk_weights_stride_k: tl.constexpr,
topk_idx_stride_m: tl.constexpr,
topk_idx_stride_k: tl.constexpr,
topk_weights_out_stride_m: tl.constexpr,
topk_weights_out_stride_k: tl.constexpr,
hidden_size: tl.constexpr,
top_k: tl.constexpr,
BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden)
GROUP_K: tl.constexpr, # 16 (NVFP4 group_size)
BLOCK_TOPK: tl.constexpr,
) -> None:
token_id = tl.program_id(0)
k_block_id = tl.program_id(1)
k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
k_mask = k_offsets < hidden_size
hidden = tl.load(
hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
mask=k_mask,
other=0.0,
).to(tl.float32)
num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
amax = tl.max(abs_groups, axis=1)
amax = tl.maximum(amax, 1.0e-4)
# ---- UE4M3 scale computation ----
# scale = amax / 6.0 (E2M1 max value = 6)
# Then quantize scale to UE4M3 format
scale = amax / 6.0
scale_bits = scale.to(tl.uint32, bitcast=True)
scale_exp = (scale_bits >> 23) & 0xFF
scale_mant = scale_bits & 0x7FFFFF
# Convert FP32 → E4M3 manually
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
e4m3_exp = tl.maximum(e4m3_exp, 0)
e4m3_exp = tl.minimum(e4m3_exp, 15)
e4m3_mant = scale_mant >> 20
round_bit = (scale_mant >> 19) & 1
e4m3_mant = e4m3_mant + round_bit
overflow = e4m3_mant >= 8
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
e4m3_exp = tl.minimum(e4m3_exp, 15)
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
# Reconstruct dequantized scale for E2M1 quantization
e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126)
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
# ---- E2M1 FP4 quantization ----
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# Nearest-neighbor using thresholds (midpoints between consecutive values)
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
# Clamp to E2M1 range [-6, 6]
scaled = tl.maximum(scaled, -6.0)
scaled = tl.minimum(scaled, 6.0)
abs_s = tl.abs(scaled)
# Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
e2m1_idx = tl.where(abs_s < 0.25, 0,
tl.where(abs_s < 0.75, 1,
tl.where(abs_s < 1.25, 2,
tl.where(abs_s < 1.75, 3,
tl.where(abs_s < 2.5, 4,
tl.where(abs_s < 3.5, 5,
tl.where(abs_s < 5.0, 6, 7)))))))
sign_bit = (scaled < 0).to(tl.int32)
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
# Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble
PACKED_K: tl.constexpr = BLOCK_K // 2 # 64
e2m1_4bit_flat = tl.reshape(e2m1_4bit, [BLOCK_K])
# Interleave: index 0,2,4,... → low nibbles; 1,3,5,... → high nibbles
even = e2m1_4bit_flat[0::2]
odd = e2m1_4bit_flat[1::2]
packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8)
packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
packed_k_mask = packed_k_offsets < (hidden_size // 2)
tl.store(
x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
packed_byte,
mask=packed_k_mask,
)
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
scale_offsets = tl.arange(0, num_groups)
packed_scale = tl.sum(scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32)
tl.store(
x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
packed_scale,
)
if k_block_id == 0:
topk_offsets = tl.arange(0, BLOCK_TOPK)
topk_mask = topk_offsets < top_k
ids = tl.load(
topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
mask=topk_mask,
other=0,
).to(tl.int64)
tl.store(
topk_idx_out
+ token_id * topk_idx_stride_m
+ topk_offsets * topk_idx_stride_k,
ids,
mask=topk_mask,
)
weights = tl.load(
topk_weights
+ token_id * topk_weights_stride_m
+ topk_offsets * topk_weights_stride_k,
mask=topk_mask,
other=0.0,
)
tl.store(
topk_weights_out
+ token_id * topk_weights_out_stride_m
+ topk_weights_out_stride_k,
weights,
mask=topk_mask,
)
def _stage_deepseek_v4_mega_moe_inputs(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
x_sf: torch.Tensor, # int32, shape (M, K//64)
topk_idx_out: torch.Tensor,
topk_weights_out: torch.Tensor,
) -> None:
num_tokens, hidden_size = hidden_states.shape
if num_tokens == 0:
return
if hidden_size % 128 != 0:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires hidden_size to be "
"a multiple of 128."
)
top_k = topk_ids.shape[1]
if topk_weights.shape != topk_ids.shape:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires topk_weights and "
"topk_ids to have the same shape."
)
block_k = 128
grid = (num_tokens, triton.cdiv(hidden_size, block_k))
block_topk = triton.next_power_of_2(top_k)
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
hidden_states,
x_fp4,
x_sf,
topk_ids,
topk_weights,
topk_idx_out,
topk_weights_out,
hidden_states.stride(0),
hidden_states.stride(1),
x_fp4.stride(0),
x_fp4.stride(1),
x_sf.stride(0),
x_sf.stride(1),
topk_ids.stride(0),
topk_ids.stride(1),
topk_weights.stride(0),
topk_weights.stride(1),
topk_idx_out.stride(0),
topk_idx_out.stride(1),
topk_weights_out.stride(0),
topk_weights_out.stride(1),
hidden_size,
top_k,
BLOCK_K=block_k,
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
BLOCK_TOPK=block_topk,
num_warps=4,
)