diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py
index ed519b2..f4cd0f4 100644
--- a/deep_gemm/mega/__init__.py
+++ b/deep_gemm/mega/__init__.py
@@ -163,23 +163,54 @@ def transform_nvfp4_weights_for_mega_moe(
     l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)
 
     num_experts = l1_weights[0].shape[0]
     l1_n = l1_weights[0].shape[1]  # intermediate_size * 2
     l1_k = l1_weights[0].shape[2] * 2  # K (weight is K//2 uint8)
     l2_n = l2_weights[0].shape[1]
     l2_k = l2_weights[0].shape[2] * 2
 
+    # Pack UE4M3 (float8_e4m3fn) into int32 for DeepGEMM TMA consumption
+    # 4 UE4M3 bytes → 1 int32, matching the hardware's 4X scale vector
+    def pack_ue4m3_to_int32(sf):
+        """Pack each group of 4 consecutive last-dim bytes of `sf` into one int32.
+
+        Byte `4 * i + j` goes to bits `[8 * j, 8 * j + 8)` of output element `i`,
+        so the last dimension shrinks by a factor of 4.
+
+        Raises:
+            ValueError: If the byte length of the last dimension is not a multiple of 4.
+        """
+        sf_u8 = sf.view(torch.uint8)
+        # Raise instead of `assert`: asserts are stripped under `python -O`
+        if sf_u8.shape[-1] % 4 != 0:
+            raise ValueError(
+                'SF last dimension must be a multiple of 4 to pack into int32, '
+                f'got {sf_u8.shape[-1]}')
+        packed = (sf_u8[..., 0::4].to(torch.int32) |
+                  (sf_u8[..., 1::4].to(torch.int32) << 8) |
+                  (sf_u8[..., 2::4].to(torch.int32) << 16) |
+                  (sf_u8[..., 3::4].to(torch.int32) << 24))
+        return packed.contiguous()
+
+    l1_sf_packed = pack_ue4m3_to_int32(l1_sf)
+    l2_sf_packed = pack_ue4m3_to_int32(l2_sf)
+
     # Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function
-    # recipe (1, 1, 16): gran_mn=1, gran_k=16
-    l1_sf_packed = transform_sf_into_required_layout(
-        l1_sf, l1_n, l1_k, (1, 16), num_experts)
-    l2_sf_packed = transform_sf_into_required_layout(
-        l2_sf, l2_n, l2_k, (1, 16), num_experts)
+    # Pass as kInt with recipe (1, 16): gran_mn=1, gran_k=16
+    # After packing, each int32 carries 4 UE4M3 scales, i.e. covers 4 * gran_k = 64 K elements
+    # check_sf_layout expects: sf.size(-1) = ceil_div(k, gran_k * 4) = ceil_div(k, 64)
+    # Our packed shape is (experts, mn, K/64) — matches!
+    l1_sf_transformed = transform_sf_into_required_layout(
+        l1_sf_packed, l1_n, l1_k, (1, 16), num_experts)
+    l2_sf_transformed = transform_sf_into_required_layout(
+        l2_sf_packed, l2_n, l2_k, (1, 16), num_experts)
 
     # L1: interleave gate/up
-    l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed))
+    # Interleave the transformed SF (same order as before this change) and return the
+    # interleaved SF below, so the SF rows stay paired with the interleaved weight rows
+    l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_transformed))
 
     # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
     l1_out = (l1_interleaved[0].view(torch.int8), l1_interleaved[1])
-    l2_out = (l2_weights[0].view(torch.int8), l2_sf_packed)
+    l2_out = (l2_weights[0].view(torch.int8), l2_sf_transformed)
 
     return l1_out, l2_out