The interleave assumed that gate/up were pre-interleaved in groups of 16 and that we needed the 2CTA UMMA layout. Both assumptions were wrong: (1) the vLLM w13_weight is a plain concat [gate; up] along the output dim; (2) our CUTLASS kernel uses ClusterShape 1x1x1, not 2CTA. The interleave was shuffling the weights into nonsense, making the L1 GEMM compute the wrong thing, and chunk(2) would then split the wrong halves.
106 lines
4.1 KiB
Python
106 lines
4.1 KiB
Python
"""
|
|
NVFP4 Weight Transformation for CUTLASS mega_moe kernel.
|
|
|
|
Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float32 global scale)
|
|
into the format expected by the CUTLASS block-scaled GEMM kernel:
|
|
- Packed FP4 weights (int8, K-major)
|
|
- UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles layout)
|
|
|
|
This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.
|
|
|
|
Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:
|
|
transform_nvfp4_weights_for_mega_moe(
|
|
(l1_weight, l1_weight_scale),
|
|
(l2_weight, l2_weight_scale),
|
|
l1_weight_scale_2=...,
|
|
l2_weight_scale_2=...,
|
|
)
|
|
"""
|
|
|
|
import torch
|
|
|
|
|
|
def _fold_global_scale(
    weight_scale: torch.Tensor,    # (E, N, K//16) float8_e4m3fn
    weight_scale_2: torch.Tensor,  # (E, num_logical) or (E, 1) or scalar float32
    logical_widths: list[int] = None,
) -> torch.Tensor:
    """Fold global scale into block scales: UE4M3 * FP32 -> float32.

    Args:
        weight_scale: Block scales; converted to float32 before multiplying.
            The second-to-last dim is treated as the output-row dim (N).
        weight_scale_2: Global scale(s). Interpreted by shape:
            * one element: a single scale applied to every block scale;
            * trailing dim == len(logical_widths): one scale per logical
              sub-matrix (e.g. gate / up). Leading dims are flattened and
              must be either a single shared row or one row per expert, so
              each expert gets its own scales;
            * anything else: a per-expert / per-row scale that is broadcast
              after appending trailing singleton dims.
        logical_widths: Number of output rows owned by each logical
            sub-matrix, in order, or None when there is no such split.

    Returns:
        float32 tensor of folded scales with the shape of ``weight_scale``.

    Raises:
        ValueError: if the per-logical path is selected but
            ``sum(logical_widths)`` does not match the row dim of
            ``weight_scale``.

    Note:
        A 1-D ``weight_scale_2`` whose length equals ``len(logical_widths)``
        is treated as per-logical scales shared by all experts. Pass
        per-expert scalars as ``(E, 1)`` to avoid that ambiguity.
    """
    sf_f32 = weight_scale.to(torch.float32)
    gs = weight_scale_2.to(torch.float32)

    # Single global scale: plain broadcast multiply.
    if gs.numel() == 1:
        return sf_f32 * gs

    n_logical = len(logical_widths) if logical_widths is not None else 0

    # Per-logical scales: select on the trailing dim so that (E,) / (E, 1)
    # per-expert scales are NOT misread as [gate, up] scales.
    if n_logical > 0 and gs.dim() >= 1 and gs.shape[-1] == n_logical:
        n_rows = sf_f32.shape[-2]
        if sum(logical_widths) != n_rows:
            raise ValueError(
                f"logical_widths {list(logical_widths)} sum to "
                f"{sum(logical_widths)}, expected {n_rows} rows"
            )
        repeats = torch.tensor(
            list(logical_widths), dtype=torch.long, device=gs.device
        )
        # (rows, num_logical) -> (rows, N): each logical scale is repeated
        # over the output rows it owns, independently for every row of gs.
        row_scale = torch.repeat_interleave(
            gs.reshape(-1, n_logical), repeats, dim=1
        )
        if row_scale.shape[0] == 1:
            # Shared across experts: drop the leading dim so the result
            # keeps the rank of weight_scale.
            row_scale = row_scale[0]
        return sf_f32 * row_scale.unsqueeze(-1)

    # Per-expert or per-row global scale — broadcast multiply.
    while gs.dim() < sf_f32.dim():
        gs = gs.unsqueeze(-1)
    return sf_f32 * gs.expand_as(sf_f32)
|
|
|
|
|
|
def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
    """Pack each run of 4 consecutive scale bytes into one 32-bit word.

    The input is reinterpreted as raw uint8 bytes. Along the last dim,
    byte ``4*i`` lands in bits 0-7 of output word ``i``, byte ``4*i + 1``
    in bits 8-15, and so on. The result is contiguous and has dtype
    ``torch.int32``, which carries the packed bit pattern.

    Raises:
        AssertionError: if the byte length of the last dim is not a
            multiple of 4.
    """
    raw = sf.view(torch.uint8)
    last_dim = raw.shape[-1]
    assert last_dim % 4 == 0, f"Last dim {last_dim} not divisible by 4"

    # Group bytes as (..., n_words, 4) and widen so the shifts have room.
    quads = raw.reshape(*raw.shape[:-1], last_dim // 4, 4).to(torch.int32)
    byte0, byte1, byte2, byte3 = quads.unbind(dim=-1)

    words = byte0 | (byte1 << 8) | (byte2 << 16) | (byte3 << 24)
    return words.contiguous()
|
|
|
|
|
|
def transform_nvfp4_weights_for_mega_moe(
|
|
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
|
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
|
l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1
|
|
l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2
|
|
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
|
"""Transform NVFP4 weights for the CUTLASS block-scaled GEMM.
|
|
|
|
Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights.
|
|
|
|
Args:
|
|
l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj
|
|
l2_tuple: (w2_weight, w2_weight_scale) — down proj
|
|
l1_weight_scale_2: global scale for L1 (float32)
|
|
l2_weight_scale_2: global scale for L2 (float32)
|
|
|
|
Returns:
|
|
((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed))
|
|
"""
|
|
l1_weight, l1_weight_scale = l1_tuple
|
|
l2_weight, l2_weight_scale = l2_tuple
|
|
|
|
# Fold global scales into block scales
|
|
if l1_weight_scale_2 is not None:
|
|
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2,
|
|
logical_widths=[3072, 3072])
|
|
else:
|
|
l1_sf_folded = l1_weight_scale.to(torch.float32)
|
|
|
|
if l2_weight_scale_2 is not None:
|
|
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2)
|
|
else:
|
|
l2_sf_folded = l2_weight_scale.to(torch.float32)
|
|
|
|
# Clamp and convert back to UE4M3
|
|
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
|
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
|
|
|
# L1 weights: plain concat [gate; up] — no interleave needed
|
|
# (Our CUTLASS kernel uses 1x1x1 ClusterShape, not 2CTA)
|
|
l1_weight_out = l1_weight.contiguous()
|
|
l2_weight_out = l2_weight.contiguous()
|
|
|
|
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
|