fix: UE4M3 activation scales + group_size=16 for NVFP4 mega_moe

The mxf4nvf4 MMA instruction shares scale_format_ between SFA and SFB.
For NVFP4 (UE4M3), both activation and weight scales must be UE4M3.

Changes to _stage_deepseek_v4_mega_moe_inputs_kernel:
- GROUP_K=16 (was 32) — NVFP4 scale_vec::4X has group_size=16
- Scale quantization: float → float8_e4m3fn (UE4M3) instead of UE8M0
  exponent extraction (>> 23). Pack 4 UE4M3 bytes per int32.
- FP8 activation quantized against UE4M3 rounded scale

Also updated class docstring (was stale MXFP4 conversion description).
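
To illustrate why the scale format change matters, here is a minimal pure-Python sketch (not the kernel code) comparing the old UE8M0 rounding with UE4M3 rounding of the same scale. The helper names `decode_ue4m3`, `encode_ue4m3`, and `ue8m0_round_up` are hypothetical and exist only for this example; the encoder uses a brute-force nearest-value search, which may break ties differently from hardware round-to-nearest-even.

```python
import struct

E4M3_MAX = 448.0  # largest finite E4M3 value


def decode_ue4m3(byte: int) -> float:
    """Decode one UE4M3 byte (sign bit clear, exponent bias 7)."""
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:  # subnormal: (m/8) * 2^(1-7)
        return (man / 8.0) * 2.0 ** -6
    return (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def encode_ue4m3(x: float) -> int:
    """Return the finite UE4M3 code nearest to x (illustrative, not RNE-exact)."""
    x = min(max(x, 0.0), E4M3_MAX)
    # Codes 0x00..0x7E are finite; 0x7F is NaN in float8_e4m3fn.
    return min(range(0x7F), key=lambda b: abs(decode_ue4m3(b) - x))


def ue8m0_round_up(x: float) -> float:
    """Mirror the removed kernel path: take the FP32 exponent, round up to a power of two."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    exp = ((bits >> 23) & 0xFF) + (1 if bits & 0x7FFFFF else 0)
    exp = min(max(exp, 1), 254)
    return 2.0 ** (exp - 127)


scale = 0.3  # e.g. a group with amax of roughly 134.4, divided by 448
code = encode_ue4m3(scale)
print(hex(code), decode_ue4m3(code), ue8m0_round_up(scale))
```

In this example the UE8M0 path rounds 0.3 up to 0.5, while UE4M3 keeps three mantissa bits and lands on 0.3125, so the activations are quantized against a scale much closer to the true one.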
2026-05-11 16:12:36 +00:00
parent 220649c188
commit 5faf9916eb


@@ -289,16 +289,32 @@ def _stage_deepseek_v4_mega_moe_inputs_kernel(
 amax = tl.max(hidden_groups, axis=1)
 amax = tl.maximum(amax, 1.0e-4)
-scale = amax / 448.0
-scale_bits = scale.to(tl.uint32, bitcast=True)
-scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
-tl.uint32
-)
-scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
-rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)
+# NVFP4: UE4M3 activation scales (not UE8M0)
+# scale_format_ is shared between SFA and SFB in the MMA descriptor,
+# so both must use the same format. For mxf4nvf4, both are UE4M3.
+scale = amax / 448.0 # UE4M3 max = 448
+# Quantize to UE4M3: float → float8_e4m3fn
+# E4M3: 1 sign + 4 exp + 3 mantissa, max normal = 448
+# For unsigned (UE4M3), we zero the sign bit
+scale_f32 = scale.to(tl.float32)
+scale_bits = scale_f32.to(tl.uint32, bitcast=True)
+# Convert with Triton's native FP8 E4M3 type (tl.float8e4nv); no manual
+# bit manipulation of the FP32 value is needed.
+# bitcast=True reinterprets the FP8 byte as uint8; a plain .to(tl.uint8)
+# would convert the numeric value instead of exposing the encoding.
+scale_e4m3 = scale.to(tl.float8e4nv) # hardware native E4M3
+scale_u8 = scale_e4m3.to(tl.uint8, bitcast=True)
+# Clear sign bit for UE4M3
+scale_u8 = scale_u8 & 0x7F
-hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
-scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
+# Dequantize the rounded scale for FP8 activation quantization
+# UE4M3 value = (1 + m/8) * 2^(e-7) for normal, or (m/8) * 2^(1-7) for subnormal
+# For simplicity, reconstruct from the rounded FP8 value
+rounded_scale = scale_e4m3.to(tl.float32) # dequantize UE4M3 → FP32
+# Quantize the activations against the rounded scale:
+# activation_fp8 = activation / rounded_scale
+hidden_groups_f = tl.reshape(hidden, [num_groups, GROUP_K])
+scaled = hidden_groups_f * (1.0 / tl.maximum(rounded_scale, 1e-6))[:, None]
 scaled = tl.reshape(scaled, [BLOCK_K])
 fp8 = scaled.to(tl.float8e4nv)
 tl.store(
tl.store(
@@ -307,8 +323,9 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
 mask=k_mask,
 )
+# Pack 4 UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
 scale_offsets = tl.arange(0, num_groups)
-packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
+packed_scale = tl.sum(scale_u8.to(tl.int32) << (scale_offsets * 8), axis=0).to(tl.int32)
 tl.store(
 x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
 packed_scale,
@@ -399,7 +416,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
 hidden_size,
 top_k,
 BLOCK_K=block_k,
-GROUP_K=32,
+GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
 BLOCK_TOPK=block_topk,
 num_warps=4,
 )
@@ -428,14 +445,11 @@ class DeepseekV4MegaMoEExperts(nn.Module):
"""MegaMoE experts for DeepSeek V4 with NVFP4 quantization.
 Loads NVFP4 expert weights (E2M1 packed uint8 + float8_e4m3fn block scales
-+ float32 global scales) and converts them to MXFP4 format for the
-DeepGEMM fp8_fp4_mega_moe kernel at finalize_weights time.
++ float32 global scales) and feeds them natively to the DeepGEMM
+fp8_nvfp4_mega_moe kernel (kind::mxf4nvf4.scale_vec::4X).
-NVFP4 → MXFP4 conversion:
-1. Unpack E2M1 FP4 → BF16
-2. Dequantize with UE8M0 block_scale * float32 global_scale
-3. Re-quantize BF16 → MXFP4 (E2M1 + UE8M0, group_size=32)
-4. Feed to deep_gemm transform_weights_for_mega_moe
+No conversion to MXFP4. Experts stay NVFP4. The global scale (weight_scale_2)
+is folded into the block scales before kernel consumption.
 """
 _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}
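
The scale-packing step in the kernel hunk above (four UE4M3 bytes per int32, one byte per 16-element group) can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the kernel itself: the function names `pack_scale_bytes` and `unpack_scale_bytes` are hypothetical, and the byte order (group 0 in the lowest byte) is taken from the `scale_offsets * 8` shift in the diff.

```python
def pack_scale_bytes(scale_bytes):
    """Pack four 8-bit scale codes into one signed 32-bit word, group 0 lowest."""
    assert len(scale_bytes) == 4
    word = 0
    for i, b in enumerate(scale_bytes):
        word |= (b & 0xFF) << (8 * i)
    # Reinterpret as signed int32. With UE4M3 the sign bit of every byte is
    # clear, so the packed word is never negative in practice.
    return word - (1 << 32) if word >= (1 << 31) else word


def unpack_scale_bytes(word):
    """Inverse of pack_scale_bytes: recover the four 8-bit codes."""
    word &= 0xFFFFFFFF
    return [(word >> (8 * i)) & 0xFF for i in range(4)]


codes = [0x2A, 0x38, 0x7E, 0x00]
packed = pack_scale_bytes(codes)
# The kernel sums the shifted bytes; since the byte lanes never overlap,
# a sum and a bitwise OR give the same word.
summed = sum(b << (8 * i) for i, b in enumerate(codes))
print(hex(packed), packed == summed, unpack_scale_bytes(packed))
```

Because each shifted byte occupies its own 8-bit lane, the `tl.sum` reduction in the kernel acts as a bitwise OR, which is why the cast to `tl.int32` before shifting is needed to keep the upper lanes from being truncated.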