feat: FP4 staging kernel - BF16 → E2M1 packed + UE4M3 block16 scales
mxf4nvf4 requires FP4×FP4, not FP8×FP4. - New staging kernel: E2M1 nearest-neighbor quantization - Output: uint8 packed (2 E2M1 per byte) + UE4M3 packed int32 scales - Added CUDA sync diagnostics for error localization
This commit is contained in:
@@ -244,11 +244,23 @@ class DeepseekV4FP8Config(Fp8Config):
|
||||
return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4"
|
||||
|
||||
|
||||
@triton.jit
|
||||
"""
|
||||
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
|
||||
|
||||
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
|
||||
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
|
||||
"""
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
hidden_states,
|
||||
x_fp8,
|
||||
x_sf,
|
||||
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
|
||||
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
topk_idx_out,
|
||||
@@ -269,8 +281,8 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
topk_weights_out_stride_k: tl.constexpr,
|
||||
hidden_size: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
GROUP_K: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden)
|
||||
GROUP_K: tl.constexpr, # 16 (NVFP4 group_size)
|
||||
BLOCK_TOPK: tl.constexpr,
|
||||
) -> None:
|
||||
token_id = tl.program_id(0)
|
||||
@@ -284,63 +296,76 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
num_groups: tl.constexpr = BLOCK_K // GROUP_K
|
||||
hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
|
||||
amax = tl.max(hidden_groups, axis=1)
|
||||
num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8
|
||||
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
|
||||
abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
|
||||
amax = tl.max(abs_groups, axis=1)
|
||||
amax = tl.maximum(amax, 1.0e-4)
|
||||
|
||||
# NVFP4: UE4M3 activation scales (not UE8M0)
|
||||
# scale_format_ is shared between SFA and SFB in the MMA descriptor,
|
||||
# so both must use the same format. For mxf4nvf4, both are UE4M3.
|
||||
# E4M3 format: 1 sign + 4 exp + 3 mantissa, bias=7, max normal=448
|
||||
# UE4M3: sign bit cleared, same representation otherwise
|
||||
scale = amax / 448.0 # Normalize so max activation maps to UE4M3 max
|
||||
# ---- UE4M3 scale computation ----
|
||||
# scale = amax / 6.0 (E2M1 max value = 6)
|
||||
# Then quantize scale to UE4M3 format
|
||||
scale = amax / 6.0
|
||||
scale_bits = scale.to(tl.uint32, bitcast=True)
|
||||
scale_exp = (scale_bits >> 23) & 0xFF
|
||||
scale_mant = scale_bits & 0x7FFFFF
|
||||
|
||||
# Convert FP32 → E4M3 manually
|
||||
# Normal: exp in [1, 254], mantissa has 23 bits → take top 3
|
||||
# For E4M3: exp_bits = FP32_exp - 127 + 7 = FP32_exp - 120, clamped to [1, 15]
|
||||
# mant_bits = FP32_mant >> 20 (top 3 of 23)
|
||||
# Round: if any of the dropped mantissa bits are set, add 1 to mant
|
||||
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
|
||||
# Handle subnormal: if FP32 exp < 121, E4M3 exp would be 0 (subnormal)
|
||||
e4m3_exp = tl.maximum(e4m3_exp, 0)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
# Mantissa: top 3 bits of FP32 mantissa, with rounding
|
||||
e4m3_mant = scale_mant >> 20 # top 3 bits
|
||||
round_bit = (scale_mant >> 19) & 1 # 4th bit for rounding
|
||||
e4m3_mant = scale_mant >> 20
|
||||
round_bit = (scale_mant >> 19) & 1
|
||||
e4m3_mant = e4m3_mant + round_bit
|
||||
# If mantissa overflowed (8), carry into exponent
|
||||
overflow = e4m3_mant >= 8
|
||||
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
|
||||
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
# Pack: UE4M3 = (exp << 3) | mant (no sign bit)
|
||||
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant # uint8 representation
|
||||
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
|
||||
|
||||
# Reconstruct the dequantized scale for FP8 activation quantization
|
||||
# UE4M3 dequant: if exp == 0: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6
|
||||
# else: value = (1 + mant/8) * 2^(exp-7)
|
||||
# Use bit manipulation: 2^n = (n+127) << 23 reinterpreted as float32
|
||||
# Reconstruct dequantized scale for E2M1 quantization
|
||||
e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126)
|
||||
# For exp=0 (subnormal): already handled separately
|
||||
# For normal: (1 + mant/8) * 2^(exp-7)
|
||||
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
|
||||
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
|
||||
normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp
|
||||
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625 # 2^-6
|
||||
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
|
||||
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
|
||||
|
||||
hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K])
|
||||
scaled = hidden_groups_f * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
|
||||
scaled = tl.reshape(scaled, [BLOCK_K])
|
||||
fp8 = scaled.to(tl.float8e4nv)
|
||||
# ---- E2M1 FP4 quantization ----
|
||||
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
# Nearest-neighbor using thresholds (midpoints between consecutive values)
|
||||
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
|
||||
# Clamp to E2M1 range [-6, 6]
|
||||
scaled = tl.maximum(scaled, -6.0)
|
||||
scaled = tl.minimum(scaled, 6.0)
|
||||
|
||||
abs_s = tl.abs(scaled)
|
||||
# Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
# [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
|
||||
e2m1_idx = tl.where(abs_s < 0.25, 0,
|
||||
tl.where(abs_s < 0.75, 1,
|
||||
tl.where(abs_s < 1.25, 2,
|
||||
tl.where(abs_s < 1.75, 3,
|
||||
tl.where(abs_s < 2.5, 4,
|
||||
tl.where(abs_s < 3.5, 5,
|
||||
tl.where(abs_s < 5.0, 6, 7)))))))
|
||||
sign_bit = (scaled < 0).to(tl.int32)
|
||||
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
|
||||
|
||||
# Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble
|
||||
PACKED_K: tl.constexpr = BLOCK_K // 2 # 64
|
||||
e2m1_4bit_flat = tl.reshape(e2m1_4bit, [BLOCK_K])
|
||||
# Interleave: index 0,2,4,... → low nibbles; 1,3,5,... → high nibbles
|
||||
even = e2m1_4bit_flat[0::2]
|
||||
odd = e2m1_4bit_flat[1::2]
|
||||
packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8)
|
||||
|
||||
packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
|
||||
packed_k_mask = packed_k_offsets < (hidden_size // 2)
|
||||
tl.store(
|
||||
x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
|
||||
fp8,
|
||||
mask=k_mask,
|
||||
x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
|
||||
packed_byte,
|
||||
mask=packed_k_mask,
|
||||
)
|
||||
|
||||
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
|
||||
@@ -378,7 +403,7 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
tl.store(
|
||||
topk_weights_out
|
||||
+ token_id * topk_weights_out_stride_m
|
||||
+ topk_offsets * topk_weights_out_stride_k,
|
||||
+ topk_weights_out_stride_k,
|
||||
weights,
|
||||
mask=topk_mask,
|
||||
)
|
||||
@@ -388,8 +413,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
x_fp8: torch.Tensor,
|
||||
x_sf: torch.Tensor,
|
||||
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
|
||||
x_sf: torch.Tensor, # int32, shape (M, K//64)
|
||||
topk_idx_out: torch.Tensor,
|
||||
topk_weights_out: torch.Tensor,
|
||||
) -> None:
|
||||
@@ -413,7 +438,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
block_topk = triton.next_power_of_2(top_k)
|
||||
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
|
||||
hidden_states,
|
||||
x_fp8,
|
||||
x_fp4,
|
||||
x_sf,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
@@ -421,8 +446,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
topk_weights_out,
|
||||
hidden_states.stride(0),
|
||||
hidden_states.stride(1),
|
||||
x_fp8.stride(0),
|
||||
x_fp8.stride(1),
|
||||
x_fp4.stride(0),
|
||||
x_fp4.stride(1),
|
||||
x_sf.stride(0),
|
||||
x_sf.stride(1),
|
||||
topk_ids.stride(0),
|
||||
|
||||
220
patches/staging_kernel.py
Normal file
220
patches/staging_kernel.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
|
||||
|
||||
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
|
||||
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
|
||||
"""
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
|
||||
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp4,  # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
    x_sf,  # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,  # 128 elements (loaded from hidden)
    GROUP_K: tl.constexpr,  # 16 (NVFP4 group_size)
    BLOCK_TOPK: tl.constexpr,
) -> None:
    """Quantize one (token, K-block) tile of activations to packed FP4.

    Grid: axis 0 is the token index, axis 1 is the K-block index. Each
    program:

    1. loads ``BLOCK_K`` activations and splits them into groups of
       ``GROUP_K`` elements;
    2. derives one UE4M3 scale per group from ``amax / 6`` (6 is the largest
       E2M1 magnitude);
    3. quantizes the scaled values to E2M1 codes and packs two codes per
       byte (even element in the low nibble, odd element in the high nibble)
       into ``x_fp4``;
    4. packs four consecutive scale bytes per int32 (first group in the
       least-significant byte) and writes ``BLOCK_K // (4 * GROUP_K)`` words
       into ``x_sf``;
    5. on K-block 0 only, copies the token's top-k ids (as int64) and top-k
       weights into ``topk_idx_out`` / ``topk_weights_out``.

    ``x_stride_*`` are the strides of ``x_fp4`` and ``topk_idx_stride_*``
    the strides of ``topk_idx_out``.
    """
    SCALES_PER_WORD: tl.constexpr = 4
    tl.static_assert(
        BLOCK_K % (SCALES_PER_WORD * GROUP_K) == 0,
        "BLOCK_K must cover a whole number of packed int32 scale words",
    )

    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    num_groups: tl.constexpr = BLOCK_K // GROUP_K
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    amax = tl.max(tl.abs(hidden_groups), axis=1)
    amax = tl.maximum(amax, 1.0e-4)

    # ---- UE4M3 scale computation ----
    # scale = amax / 6.0 (E2M1 max value = 6), capped at the largest finite
    # E4M3 value so the encoder below can never produce the NaN code 0x7F.
    scale = tl.minimum(amax / 6.0, 448.0)

    # The scale is strictly positive, so its FP32 bit pattern is a
    # non-negative int32. Signed arithmetic is required here: the re-biased
    # exponent goes negative for small scales and must not wrap around.
    scale_bits32 = scale.to(tl.int32, bitcast=True)
    fp32_exp = (scale_bits32 >> 23) & 0xFF
    fp32_mant = scale_bits32 & 0x7FFFFF
    exp_field = fp32_exp - 120  # FP32 bias=127, E4M3 bias=7

    # Normal numbers: keep the top 3 mantissa bits, rounding half up on the
    # first dropped bit. Adding (instead of OR-ing) lets a mantissa overflow
    # of 8 carry straight into the exponent field.
    mant_rounded = (fp32_mant + 0x80000) >> 20
    normal_bits = exp_field * 8 + mant_rounded

    # Subnormals (scale < 2^-6): value = m * 2^-9 with m in 0..7, so the
    # encoding is simply round(scale * 2^9). A result of 8 is the smallest
    # normal number (exp=1, mant=0), which is the correct carry.
    subnormal_bits = (scale * 512.0 + 0.5).to(tl.int32)

    scale_e4m3_bits = tl.where(exp_field >= 1, normal_bits, subnormal_bits)
    # Never emit a zero scale (it would flush the whole group to zero) and
    # never exceed 0x7E (448, the largest finite E4M3 value).
    scale_e4m3_bits = tl.minimum(tl.maximum(scale_e4m3_bits, 1), 0x7E)

    # Reconstruct the dequantized scale so the activations are divided by
    # exactly the value the consumer will multiply back in.
    q_exp = scale_e4m3_bits >> 3
    q_mant = (scale_e4m3_bits & 0x7).to(tl.float32)
    # 2^(q_exp - 7) built directly from FP32 exponent bits.
    two_pow_exp = ((q_exp + 120) << 23).to(tl.float32, bitcast=True)
    normal_value = (1.0 + q_mant * 0.125) * two_pow_exp
    subnormal_value = q_mant * 0.001953125  # mant * 2^-9
    e4m3_value = tl.where(q_exp == 0, subnormal_value, normal_value)

    # ---- E2M1 FP4 quantization ----
    # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]; the LUT index is
    # also the 3-bit magnitude encoding.
    scaled = hidden_groups * (1.0 / e4m3_value)[:, None]
    # Clamp to E2M1 range [-6, 6]
    scaled = tl.minimum(tl.maximum(scaled, -6.0), 6.0)

    abs_s = tl.abs(scaled)
    # Nearest-neighbor via midpoints between consecutive LUT entries:
    # [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
    e2m1_idx = tl.where(abs_s < 0.25, 0,
               tl.where(abs_s < 0.75, 1,
               tl.where(abs_s < 1.25, 2,
               tl.where(abs_s < 1.75, 3,
               tl.where(abs_s < 2.5, 4,
               tl.where(abs_s < 3.5, 5,
               tl.where(abs_s < 5.0, 6, 7)))))))
    sign_bit = (scaled < 0).to(tl.int32)
    e2m1_4bit = (sign_bit << 3) | e2m1_idx  # 4-bit: (sign << 3) | index

    # Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble.
    # Triton tensors do not support strided slicing, so de-interleave by
    # viewing the codes as (pairs, 2) and splitting the trailing dimension.
    PACKED_K: tl.constexpr = BLOCK_K // 2
    low_nibble, high_nibble = tl.split(tl.reshape(e2m1_4bit, [PACKED_K, 2]))
    packed_byte = ((high_nibble << 4) | low_nibble).to(tl.uint8)

    packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
    packed_k_mask = packed_k_offsets < (hidden_size // 2)
    tl.store(
        x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
        packed_byte,
        mask=packed_k_mask,
    )

    # Pack 4 UE4M3 bytes into each int32 (NVFP4: group_size=16, 4 groups per
    # 64 elements). A BLOCK_K tile therefore yields BLOCK_K // 64 words, which
    # land in consecutive columns of x_sf.
    WORDS_PER_BLOCK: tl.constexpr = num_groups // SCALES_PER_WORD
    byte_shifts = tl.arange(0, SCALES_PER_WORD) * 8
    scale_words = tl.reshape(scale_e4m3_bits, [WORDS_PER_BLOCK, SCALES_PER_WORD])
    # Bytes occupy disjoint bit ranges and are <= 0x7E, so the sum equals a
    # bitwise OR and cannot overflow the signed 32-bit word.
    packed_scale = tl.sum(scale_words << byte_shifts[None, :], axis=1).to(tl.int32)
    sf_offsets = k_block_id * WORDS_PER_BLOCK + tl.arange(0, WORDS_PER_BLOCK)
    sf_mask = sf_offsets < (hidden_size // (SCALES_PER_WORD * GROUP_K))
    tl.store(
        x_sf + token_id * x_sf_stride_m + sf_offsets * x_sf_stride_k,
        packed_scale,
        mask=sf_mask,
    )

    # Routing metadata is per token, so only the first K-block copies it.
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k

        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )

        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        # Each top-k slot gets its own column; dropping `topk_offsets` here
        # would write every weight to the same address.
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )
|
||||
|
||||
|
||||
def _stage_deepseek_v4_mega_moe_inputs(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
|
||||
x_sf: torch.Tensor, # int32, shape (M, K//64)
|
||||
topk_idx_out: torch.Tensor,
|
||||
topk_weights_out: torch.Tensor,
|
||||
) -> None:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
if num_tokens == 0:
|
||||
return
|
||||
if hidden_size % 128 != 0:
|
||||
raise ValueError(
|
||||
"DeepSeek V4 MegaMoE input staging requires hidden_size to be "
|
||||
"a multiple of 128."
|
||||
)
|
||||
top_k = topk_ids.shape[1]
|
||||
if topk_weights.shape != topk_ids.shape:
|
||||
raise ValueError(
|
||||
"DeepSeek V4 MegaMoE input staging requires topk_weights and "
|
||||
"topk_ids to have the same shape."
|
||||
)
|
||||
|
||||
block_k = 128
|
||||
grid = (num_tokens, triton.cdiv(hidden_size, block_k))
|
||||
block_topk = triton.next_power_of_2(top_k)
|
||||
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
|
||||
hidden_states,
|
||||
x_fp4,
|
||||
x_sf,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
topk_idx_out,
|
||||
topk_weights_out,
|
||||
hidden_states.stride(0),
|
||||
hidden_states.stride(1),
|
||||
x_fp4.stride(0),
|
||||
x_fp4.stride(1),
|
||||
x_sf.stride(0),
|
||||
x_sf.stride(1),
|
||||
topk_ids.stride(0),
|
||||
topk_ids.stride(1),
|
||||
topk_weights.stride(0),
|
||||
topk_weights.stride(1),
|
||||
topk_idx_out.stride(0),
|
||||
topk_idx_out.stride(1),
|
||||
topk_weights_out.stride(0),
|
||||
topk_weights_out.stride(1),
|
||||
hidden_size,
|
||||
top_k,
|
||||
BLOCK_K=block_k,
|
||||
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
|
||||
BLOCK_TOPK=block_topk,
|
||||
num_warps=4,
|
||||
)
|
||||
Reference in New Issue
Block a user