diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 1d61f9bc..ed4fd6a5 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -59,17 +59,6 @@ def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor: return packed.contiguous() -def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor: - """Interleave L1 (gate_up) weights for 2CTA UMMA.""" - E, N, K_half = weight.shape - assert N % 16 == 0, f"N={N} not divisible by 16" - w = weight.view(E, N // 16, 16, K_half) - w_gate = w[:, :, :8, :] - w_up = w[:, :, 8:16, :] - interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half) - return interleaved.contiguous() - - def transform_nvfp4_weights_for_mega_moe( l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) @@ -108,8 +97,9 @@ def transform_nvfp4_weights_for_mega_moe( l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() - # Interleave L1 weights (gate_up for 2CTA UMMA) - l1_weight_out = _interleave_l1_weights(l1_weight) + # L1 weights: plain concat [gate; up] — no interleave needed + # (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA) + l1_weight_out = l1_weight.contiguous() l2_weight_out = l2_weight.contiguous() return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)