diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index ced23f8..ed519b2 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -144,41 +144,42 @@ def transform_nvfp4_weights_for_mega_moe( ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: """Transform NVFP4 expert weights for the mega_moe kernel. - NVFP4 weights come as (weight, scale) where: - - weight: uint8 E2M1 packed, shape (num_experts, N, K//2) - - scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16) - - The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout. - - If weight_scale_2 (float32 global scale) is provided, it is folded into the - block scales: effective_scale = block_scale * global_scale → re-quantized to UE4M3. - This is needed because the kernel only applies one level of scaling. + Uses deep_gemm.transform_sf_into_required_layout for proper TMA-aligned + UTCCP layout with recipe (1, 1, 16) for NVFP4 group_size=16. """ + from deep_gemm import transform_sf_into_required_layout + def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor: - """Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3""" if scale_2 is None: return sf - # Dequantize UE4M3 → float32 sf_f32 = sf.to(torch.float32) - # Multiply by global scale - # scale_2 shape: (num_experts, 1, 1) or (num_experts,) — broadcast over mn and K if scale_2.dim() == 1: scale_2 = scale_2.view(-1, 1, 1) sf_f32 = sf_f32 * scale_2 - # Re-quantize to UE4M3 (clamp to [0, 448]) sf_f32 = sf_f32.clamp(0.0, 448.0) return sf_f32.to(torch.float8_e4m3fn) - # Fold global scales into block scales l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2) l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2) - # L1: interleave gate/up, then pack + transpose SF for UTCCP - l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf)) - # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8), but NVFP4 weights are uint8 - l1_out = (l1_interleaved[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) - # L2: only pack + transpose SF for UTCCP - l2_out = (l2_weights[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l2_sf)) + num_experts = l1_weights[0].shape[0] + l1_n = l1_weights[0].shape[1] # intermediate_size * 2 + l1_k = l1_weights[0].shape[2] * 2 # K (weight is K//2 uint8) + l2_n = l2_weights[0].shape[1] + l2_k = l2_weights[0].shape[2] * 2 + + # Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function + # recipe (1, 1, 16): gran_mn=1, gran_k=16 + l1_sf_packed = transform_sf_into_required_layout( + l1_sf, l1_n, l1_k, (1, 16), num_experts) + l2_sf_packed = transform_sf_into_required_layout( + l2_sf, l2_n, l2_k, (1, 16), num_experts) + + # L1: interleave gate/up + l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed)) + # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8) + l1_out = (l1_interleaved[0].view(torch.int8), l1_interleaved[1]) + l2_out = (l2_weights[0].view(torch.int8), l2_sf_packed) return l1_out, l2_out