fix: remove broken L1 weight interleave

The interleave assumed gate/up were pre-interleaved in groups of 16
and that we needed 2CTA UMMA layout. Both wrong:
1. vLLM w13_weight is plain concat [gate; up] along output dim
2. Our CUTLASS kernel uses ClusterShape 1x1x1, not 2CTA

The interleave was shuffling weights into nonsense, making L1 GEMM
compute the wrong thing, and chunk(2) would split wrong halves.
This commit is contained in:
2026-05-14 13:05:45 +00:00
parent 80495c0cd6
commit 1c39e21d87

View File

@@ -59,17 +59,6 @@ def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
return packed.contiguous()
def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor:
"""Interleave L1 (gate_up) weights for 2CTA UMMA."""
E, N, K_half = weight.shape
assert N % 16 == 0, f"N={N} not divisible by 16"
w = weight.view(E, N // 16, 16, K_half)
w_gate = w[:, :, :8, :]
w_up = w[:, :, 8:16, :]
interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half)
return interleaved.contiguous()
def transform_nvfp4_weights_for_mega_moe(
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
@@ -108,8 +97,9 @@ def transform_nvfp4_weights_for_mega_moe(
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
# Interleave L1 weights (gate_up for 2CTA UMMA)
l1_weight_out = _interleave_l1_weights(l1_weight)
# L1 weights: plain concat [gate; up] — no interleave needed
# (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA)
l1_weight_out = l1_weight.contiguous()
l2_weight_out = l2_weight.contiguous()
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)