diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 52fbb35..4178dd7 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -138,7 +138,9 @@ def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: def transform_nvfp4_weights_for_mega_moe( l1_weights: Tuple[torch.Tensor, torch.Tensor], - l2_weights: Tuple[torch.Tensor, torch.Tensor] + l2_weights: Tuple[torch.Tensor, torch.Tensor], + l1_weight_scale_2: Optional[torch.Tensor] = None, + l2_weight_scale_2: Optional[torch.Tensor] = None ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: """Transform NVFP4 expert weights for the mega_moe kernel. @@ -147,13 +149,36 @@ def transform_nvfp4_weights_for_mega_moe( - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16) The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout. + + If weight_scale_2 (float32 global scale) is provided, it is folded into the + block scales: effective_scale = block_scale * global_scale → re-quantized to UE4M3. + This is needed because the kernel only applies one level of scaling. """ + def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor: + """Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3""" + if scale_2 is None: + return sf + # Dequantize UE4M3 → float32 + sf_f32 = sf.to(torch.float32) + # Multiply by global scale + # scale_2 shape: (num_experts, 1, 1) or (num_experts,) — broadcast over mn and K + if scale_2.dim() == 1: + scale_2 = scale_2.view(-1, 1, 1) + sf_f32 = sf_f32 * scale_2 + # Re-quantize to UE4M3 (clamp to [0, 448]) + sf_f32 = sf_f32.clamp(0.0, 448.0) + return sf_f32.to(torch.float8_e4m3fn) + + # Fold global scales into block scales + l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2) + l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2) + # L1: interleave gate/up, then pack + transpose SF for UTCCP - l1_interleaved = _interleave_l1_weights(l1_weights) - l1_weights = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) + l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf)) + l1_out = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) # L2: only pack + transpose SF for UTCCP - l2_weights = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_weights[1])) - return l1_weights, l2_weights + l2_out = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_sf)) + return l1_out, l2_out def fp8_fp4_mega_moe(y: torch.Tensor,