Files
nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/weight_transform.py
biondizzle 1c39e21d87 fix: remove broken L1 weight interleave
The interleave assumed gate/up were pre-interleaved in groups of 16
and that we needed 2CTA UMMA layout. Both wrong:
1. vLLM w13_weight is plain concat [gate; up] along output dim
2. Our CUTLASS kernel uses ClusterShape 1x1x1, not 2CTA

The interleave was shuffling weights into nonsense, making L1 GEMM
compute the wrong thing, and chunk(2) would split wrong halves.
2026-05-14 13:05:45 +00:00

106 lines
4.1 KiB
Python

"""
NVFP4 Weight Transformation for CUTLASS mega_moe kernel.

Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale)
into the format expected by the CUTLASS block-scaled GEMM kernel:
  - Packed FP4 weights (int8, K-major)
  - UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles layout)

This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.

Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:

    transform_nvfp4_weights_for_mega_moe(
        (l1_weight, l1_weight_scale),
        (l2_weight, l2_weight_scale),
        l1_weight_scale_2=...,
        l2_weight_scale_2=...,
    )
"""
import torch
def _fold_global_scale(
    weight_scale: torch.Tensor,     # (E, N, K//16) float8_e4m3fn
    weight_scale_2: torch.Tensor,   # (E, num_logical) or (E, 1) or scalar float32
    logical_widths: list[int] = None,
) -> torch.Tensor:
    """Fold the global scale into the block scales: UE4M3 * FP32 -> float32.

    The global scale is matched against the block scales by shape:

    * one element: multiplied into every block scale;
    * ``(..., num_logical)`` with ``logical_widths`` given: each expert's
      i-th value is applied to that expert's i-th run of
      ``logical_widths[i]`` rows, so experts keep their own scales;
    * 1-D of length ``num_logical`` with ``logical_widths`` given: the
      i-th value is applied to the i-th run of rows, shared by all experts;
    * anything else (e.g. ``(E,)`` or ``(E, 1)``): trailing singleton dims
      are appended and the scale is broadcast over the block scales.

    Args:
        weight_scale: block scales; converted to float32 before scaling.
        weight_scale_2: global scale; converted to float32 before use.
        logical_widths: row counts of the logical sub-matrices stacked
            along the row dimension (e.g. gate and up), or None.

    Returns:
        float32 tensor of folded block scales.

    Raises:
        RuntimeError: propagated from torch if the shapes cannot be
            broadcast (e.g. ``sum(logical_widths)`` != number of rows).
    """
    sf_f32 = weight_scale.to(torch.float32)
    gs = weight_scale_2.to(torch.float32)

    if gs.numel() == 1:
        return sf_f32 * gs

    if logical_widths is not None:
        num_logical = len(logical_widths)

        if gs.dim() >= 2 and gs.shape[-1] == num_logical:
            # Per-expert, per-logical scales. Indexing the flattened tensor
            # here would pick expert 0's values for every expert, so expand
            # each column along the row dimension instead.
            per_expert = gs.reshape(-1, num_logical)            # (E, num_logical)
            row_scale = torch.cat(
                [per_expert[:, i:i + 1].expand(-1, w)
                 for i, w in enumerate(logical_widths)],
                dim=1,
            )                                                   # (E, N)
            return sf_f32 * row_scale.unsqueeze(-1)             # (E, N, 1)

        if gs.dim() == 1 and gs.numel() == num_logical:
            # One scale per logical sub-matrix, shared by all experts.
            row_scale = torch.cat(
                [gs[i].expand(w) for i, w in enumerate(logical_widths)]
            )                                                   # (N,)
            return sf_f32 * row_scale.unsqueeze(-1)             # (N, 1)

    # Per-expert or per-row global scale — broadcast multiply
    while gs.dim() < sf_f32.dim():
        gs = gs.unsqueeze(-1)
    return sf_f32 * gs.expand_as(sf_f32)
def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
    """Pack every 4 consecutive UE4M3 bytes of the last dim into one 32-bit word.

    The input is reinterpreted as raw bytes. Within each group of four,
    element 0 occupies the lowest byte and element 3 the highest.

    Note: despite the name, the result tensor has dtype ``torch.int32``;
    only the bit pattern is meaningful.

    Args:
        sf: tensor whose elements are one byte wide (e.g. float8_e4m3fn);
            its last dimension must be a multiple of 4.

    Returns:
        Contiguous int32 tensor with the last dimension shrunk by 4x.
    """
    raw = sf.view(torch.uint8)
    last = raw.shape[-1]
    assert last % 4 == 0, f"Last dim {last} not divisible by 4"
    # Lane k holds every 4th byte starting at offset k, moved to byte k of the word.
    lanes = [raw[..., k::4].to(torch.int32) << (8 * k) for k in range(4)]
    word = lanes[0] | lanes[1] | lanes[2] | lanes[3]
    return word.contiguous()
def transform_nvfp4_weights_for_mega_moe(
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1
l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
"""Transform NVFP4 weights for the CUTLASS block-scaled GEMM.
Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights.
Args:
l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj
l2_tuple: (w2_weight, w2_weight_scale) — down proj
l1_weight_scale_2: global scale for L1 (float32)
l2_weight_scale_2: global scale for L2 (float32)
Returns:
((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed))
"""
l1_weight, l1_weight_scale = l1_tuple
l2_weight, l2_weight_scale = l2_tuple
# Fold global scales into block scales
if l1_weight_scale_2 is not None:
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2,
logical_widths=[3072, 3072])
else:
l1_sf_folded = l1_weight_scale.to(torch.float32)
if l2_weight_scale_2 is not None:
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2)
else:
l2_sf_folded = l2_weight_scale.to(torch.float32)
# Clamp and convert back to UE4M3
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
# L1 weights: plain concat [gate; up] — no interleave needed
# (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA)
l1_weight_out = l1_weight.contiguous()
l2_weight_out = l2_weight.contiguous()
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)