diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index f4cd0f4..9bbc65f 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -172,7 +172,6 @@ def transform_nvfp4_weights_for_mega_moe( # 4 UE4M3 bytes → 1 int32, matching the hardware's 4X scale vector def pack_ue4m3_to_int32(sf): sf_u8 = sf.view(torch.uint8) - # Pack 4 consecutive uint8 bytes into int32 assert sf_u8.shape[-1] % 4 == 0 packed = (sf_u8[..., 0::4].to(torch.int32) | (sf_u8[..., 1::4].to(torch.int32) << 8) | @@ -183,15 +182,18 @@ def transform_nvfp4_weights_for_mega_moe( l1_sf_packed = pack_ue4m3_to_int32(l1_sf) l2_sf_packed = pack_ue4m3_to_int32(l2_sf) + # Reshape to 2D for transform_sf_into_required_layout + # (experts, mn, K//64) → (experts * mn, K//64) + # The C++ function expects 2D or properly-strided 3D tensors + l1_sf_2d = l1_sf_packed.reshape(-1, l1_sf_packed.shape[-1]) + l2_sf_2d = l2_sf_packed.reshape(-1, l2_sf_packed.shape[-1]) + # Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function - # Pass as kInt with recipe (1, 16): gran_mn=1, gran_k=16 - # After packing, effective K for SF is k/4 (4 UE4M3 per int32) - # check_sf_layout expects: sf.size(-1) = ceil_div(k, gran_k * 4) = ceil_div(k, 64) - # Our packed shape is (experts, mn, K/64) — matches! + # recipe (1, 16): gran_mn=1, gran_k=16 l1_sf_transformed = transform_sf_into_required_layout( - l1_sf_packed, l1_n, l1_k, (1, 16), num_experts) + l1_sf_2d, l1_n, l1_k, (1, 16), num_experts) l2_sf_transformed = transform_sf_into_required_layout( - l2_sf_packed, l2_n, l2_k, (1, 16), num_experts) + l2_sf_2d, l2_n, l2_k, (1, 16), num_experts) # L1: interleave gate/up l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed))