fix: remove broken L1 weight interleave
The interleave assumed gate/up were pre-interleaved in groups of 16 and that we needed 2CTA UMMA layout. Both wrong: 1. vLLM w13_weight is plain concat [gate; up] along output dim 2. Our CUTLASS kernel uses ClusterShape 1x1x1, not 2CTA The interleave was shuffling weights into nonsense, making L1 GEMM compute the wrong thing, and chunk(2) would split wrong halves.
This commit is contained in:
@@ -59,17 +59,6 @@ def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
|
||||
return packed.contiguous()
|
||||
|
||||
|
||||
def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Interleave L1 (gate_up) weights for 2CTA UMMA."""
|
||||
E, N, K_half = weight.shape
|
||||
assert N % 16 == 0, f"N={N} not divisible by 16"
|
||||
w = weight.view(E, N // 16, 16, K_half)
|
||||
w_gate = w[:, :, :8, :]
|
||||
w_up = w[:, :, 8:16, :]
|
||||
interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half)
|
||||
return interleaved.contiguous()
|
||||
|
||||
|
||||
def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
@@ -108,8 +97,9 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
||||
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
||||
|
||||
# Interleave L1 weights (gate_up for 2CTA UMMA)
|
||||
l1_weight_out = _interleave_l1_weights(l1_weight)
|
||||
# L1 weights: plain concat [gate; up] — no interleave needed
|
||||
# (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA)
|
||||
l1_weight_out = l1_weight.contiguous()
|
||||
l2_weight_out = l2_weight.contiguous()
|
||||
|
||||
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
|
||||
|
||||
Reference in New Issue
Block a user