feat: fold weight_scale_2 into block scales in NVFP4 transform
- transform_nvfp4_weights_for_mega_moe now accepts weight_scale_2 - Folds global scale into block scales: UE4M3 * FP32 -> UE4M3 - Dequantize to f32, multiply by global scale, clamp [0,448], re-quantize - This is needed because the kernel only applies one level of block scaling
This commit is contained in:
@@ -138,7 +138,9 @@ def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor]
|
||||
l2_weights: Tuple[torch.Tensor, torch.Tensor],
|
||||
l1_weight_scale_2: Optional[torch.Tensor] = None,
|
||||
l2_weight_scale_2: Optional[torch.Tensor] = None
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Transform NVFP4 expert weights for the mega_moe kernel.
|
||||
|
||||
@@ -147,13 +149,36 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
- scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)
|
||||
|
||||
The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout.
|
||||
|
||||
If weight_scale_2 (float32 global scale) is provided, it is folded into the
|
||||
block scales: effective_scale = block_scale * global_scale → re-quantized to UE4M3.
|
||||
This is needed because the kernel only applies one level of scaling.
|
||||
"""
|
||||
def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
"""Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3"""
|
||||
if scale_2 is None:
|
||||
return sf
|
||||
# Dequantize UE4M3 → float32
|
||||
sf_f32 = sf.to(torch.float32)
|
||||
# Multiply by global scale
|
||||
# scale_2 shape: (num_experts, 1, 1) or (num_experts,) — broadcast over mn and K
|
||||
if scale_2.dim() == 1:
|
||||
scale_2 = scale_2.view(-1, 1, 1)
|
||||
sf_f32 = sf_f32 * scale_2
|
||||
# Re-quantize to UE4M3 (clamp to [0, 448])
|
||||
sf_f32 = sf_f32.clamp(0.0, 448.0)
|
||||
return sf_f32.to(torch.float8_e4m3fn)
|
||||
|
||||
# Fold global scales into block scales
|
||||
l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2)
|
||||
l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)
|
||||
|
||||
# L1: interleave gate/up, then pack + transpose SF for UTCCP
|
||||
l1_interleaved = _interleave_l1_weights(l1_weights)
|
||||
l1_weights = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
|
||||
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf))
|
||||
l1_out = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
|
||||
# L2: only pack + transpose SF for UTCCP
|
||||
l2_weights = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1]))
|
||||
return l1_weights, l2_weights
|
||||
l2_out = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_sf))
|
||||
return l1_out, l2_out
|
||||
|
||||
|
||||
def fp8_fp4_mega_moe(y: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user