feat: fold weight_scale_2 into block scales in NVFP4 transform

- transform_nvfp4_weights_for_mega_moe now accepts optional l1_weight_scale_2 / l2_weight_scale_2 (the per-layer weight_scale_2 global scales)
- Folds global scale into block scales: UE4M3 * FP32 -> UE4M3
- Dequantize to f32, multiply by global scale, clamp [0,448], re-quantize
- This is needed because the kernel only applies one level of block scaling
This commit is contained in:
2026-05-11 05:42:16 +00:00
parent 42c215d49b
commit bbf9a5f46a

View File

@@ -138,7 +138,9 @@ def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
def transform_nvfp4_weights_for_mega_moe(
l1_weights: Tuple[torch.Tensor, torch.Tensor],
l2_weights: Tuple[torch.Tensor, torch.Tensor]
l2_weights: Tuple[torch.Tensor, torch.Tensor],
l1_weight_scale_2: Optional[torch.Tensor] = None,
l2_weight_scale_2: Optional[torch.Tensor] = None
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
"""Transform NVFP4 expert weights for the mega_moe kernel.
@@ -147,13 +149,36 @@ def transform_nvfp4_weights_for_mega_moe(
- scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)
The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout.
If l1_weight_scale_2 / l2_weight_scale_2 (float32 global scales, i.e. weight_scale_2) are
provided, each is folded into the matching block scales:
effective_scale = block_scale * global_scale, clamped to [0, 448] and re-quantized to UE4M3.
This is needed because the kernel only applies one level of scaling.
"""
def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
"""Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3"""
if scale_2 is None:
return sf
# Dequantize UE4M3 → float32
sf_f32 = sf.to(torch.float32)
# Multiply by global scale
# scale_2 shape: (num_experts, 1, 1) or (num_experts,) — broadcast over mn and K
if scale_2.dim() == 1:
scale_2 = scale_2.view(-1, 1, 1)
sf_f32 = sf_f32 * scale_2
# Re-quantize to UE4M3 (clamp to [0, 448])
sf_f32 = sf_f32.clamp(0.0, 448.0)
return sf_f32.to(torch.float8_e4m3fn)
# Fold global scales into block scales
l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2)
l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)
# L1: interleave gate/up, then pack + transpose SF for UTCCP
l1_interleaved = _interleave_l1_weights(l1_weights)
l1_weights = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf))
l1_out = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
# L2: only pack + transpose SF for UTCCP
l2_weights = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1]))
return l1_weights, l2_weights
l2_out = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_sf))
return l1_out, l2_out
def fp8_fp4_mega_moe(y: torch.Tensor,