From 1c39e21d873c2de5fbaaec946a2ff32f014aa52e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 13:05:45 +0000 Subject: [PATCH] fix: remove broken L1 weight interleave The interleave assumed gate/up were pre-interleaved in groups of 16 and that we needed 2CTA UMMA layout. Both wrong: 1. vLLM w13_weight is plain concat [gate; up] along output dim 2. Our CUTLASS kernel uses ClusterShape 1x1x1, not 2CTA The interleave was shuffling weights into nonsense, making L1 GEMM compute the wrong thing, and chunk(2) would split wrong halves. --- src/nvfp4_megamoe_kernel/weight_transform.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 1d61f9bc..ed4fd6a5 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -59,17 +59,6 @@ def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor: return packed.contiguous() -def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor: - """Interleave L1 (gate_up) weights for 2CTA UMMA.""" - E, N, K_half = weight.shape - assert N % 16 == 0, f"N={N} not divisible by 16" - w = weight.view(E, N // 16, 16, K_half) - w_gate = w[:, :, :8, :] - w_up = w[:, :, 8:16, :] - interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half) - return interleaved.contiguous() - - def transform_nvfp4_weights_for_mega_moe( l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) @@ -108,8 +97,9 @@ def transform_nvfp4_weights_for_mega_moe( l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous() - # Interleave L1 weights (gate_up for 2CTA UMMA) - l1_weight_out = _interleave_l1_weights(l1_weight) + # L1 weights: plain concat [gate; up] — no interleave needed + # (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA) + l1_weight_out = l1_weight.contiguous() l2_weight_out = l2_weight.contiguous() return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)