fix: UE4M3 activation scales + group_size=16 for NVFP4 mega_moe
The mxf4nvf4 MMA instruction shares scale_format_ between SFA and SFB. For NVFP4 (UE4M3), both activation and weight scales must be UE4M3. Changes to _stage_deepseek_v4_mega_moe_inputs_kernel: - GROUP_K=16 (was 32) — NVFP4 scale_vec::4X has group_size=16 - Scale quantization: float → float8_e4m3fn (UE4M3) instead of UE8M0 exponent extraction (>> 23). Pack 4 UE4M3 bytes per int32. - FP8 activation quantized against UE4M3 rounded scale Also updated class docstring (was stale MXFP4 conversion description).
This commit is contained in:
@@ -289,16 +289,32 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
amax = tl.max(hidden_groups, axis=1)
|
||||
amax = tl.maximum(amax, 1.0e-4)
|
||||
|
||||
scale = amax / 448.0
|
||||
scale_bits = scale.to(tl.uint32, bitcast=True)
|
||||
scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
|
||||
tl.uint32
|
||||
)
|
||||
scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
|
||||
rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)
|
||||
# NVFP4: UE4M3 activation scales (not UE8M0)
|
||||
# scale_format_ is shared between SFA and SFB in the MMA descriptor,
|
||||
# so both must use the same format. For mxf4nvf4, both are UE4M3.
|
||||
scale = amax / 448.0 # UE4M3 max = 448
|
||||
# Quantize to UE4M3: float → float8_e4m3fn
|
||||
# E4M3: 1 sign + 4 exp + 3 mantissa, max normal = 448
|
||||
# For unsigned (UE4M3), we zero the sign bit
|
||||
scale_f32 = scale.to(tl.float32)
|
||||
scale_bits = scale_f32.to(tl.uint32, bitcast=True)
|
||||
# Extract FP8 E4M3 bits: clamp to [0, 448], convert
|
||||
# Simple approach: store as float32, reinterpret low 8 bits as E4M3
|
||||
# Triton doesn't have float8, so we compute E4M3 manually
|
||||
# Actually, we can use the FP8 e4m3 type in triton
|
||||
scale_e4m3 = scale.to(tl.float8e4nv) # hardware native E4M3
|
||||
scale_u8 = scale_e4m3.to(tl.uint8)
|
||||
# Clear sign bit for UE4M3
|
||||
scale_u8 = scale_u8 & 0x7F
|
||||
|
||||
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
|
||||
scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
|
||||
# Dequantize the rounded scale for FP8 activation quantization
|
||||
# UE4M3 value = (1 + m/8) * 2^(e-7) for normal, or (m/8) * 2^(1-7) for subnormal
|
||||
# For simplicity, reconstruct from the rounded FP8 value
|
||||
rounded_scale = scale_e4m3.to(tl.float32) # dequantize UE4M3 → FP32
|
||||
# But we need the FP8 activation, quantized with this scale
|
||||
# Recompute: activation_fp8 = activation / rounded_scale
|
||||
hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K])
|
||||
scaled = hidden_groups_f * (1.0 / tl.maximum(rounded_scale, 1e-6))[:, None]
|
||||
scaled = tl.reshape(scaled, [BLOCK_K])
|
||||
fp8 = scaled.to(tl.float8e4nv)
|
||||
tl.store(
|
||||
@@ -307,8 +323,9 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
mask=k_mask,
|
||||
)
|
||||
|
||||
# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
|
||||
scale_offsets = tl.arange(0, num_groups)
|
||||
packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
|
||||
packed_scale = tl.sum(scale_u8.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32)
|
||||
tl.store(
|
||||
x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
|
||||
packed_scale,
|
||||
@@ -399,7 +416,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
hidden_size,
|
||||
top_k,
|
||||
BLOCK_K=block_k,
|
||||
GROUP_K=32,
|
||||
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
|
||||
BLOCK_TOPK=block_topk,
|
||||
num_warps=4,
|
||||
)
|
||||
@@ -428,14 +445,11 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
"""MegaMoE experts for DeepSeek V4 with NVFP4 quantization.
|
||||
|
||||
Loads NVFP4 expert weights (E2M1 packed uint8 + float8_e4m3fn block scales
|
||||
+ float32 global scales) and converts them to MXFP4 format for the
|
||||
DeepGEMM fp8_fp4_mega_moe kernel at finalize_weights time.
|
||||
+ float32 global scales) and feeds them natively to the DeepGEMM
|
||||
fp8_nvfp4_mega_moe kernel (kind::mxf4nvf4.scale_vec::4X).
|
||||
|
||||
NVFP4 → MXFP4 conversion:
|
||||
1. Unpack E2M1 FP4 → BF16
|
||||
2. Dequantize with UE8M0 block_scale * float32 global_scale
|
||||
3. Re-quantize BF16 → MXFP4 (E2M1 + UE8M0, group_size=32)
|
||||
4. Feed to deep_gemm transform_weights_for_mega_moe
|
||||
No conversion to MXFP4. Experts stay NVFP4. The global scale (weight_scale_2)
|
||||
is folded into the block scales before kernel consumption.
|
||||
"""
|
||||
_symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user