Fix weight_transform signature to match nightly vLLM finalize_weights call

This commit is contained in:
2026-05-13 21:41:43 +00:00
parent 47ca5631d8
commit 1a452ffabd

View File

@@ -7,112 +7,117 @@ into the TMA-aligned format expected by the block-scaled GEMM kernel:
- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout)
This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.
Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:
transform_nvfp4_weights_for_mega_moe(
(l1_weight, l1_weight_scale),
(l2_weight, l2_weight_scale),
l1_weight_scale_2=...,
l2_weight_scale_2=...,
)
"""
import torch
import math
def _fold_global_scale(
    weight_scale: torch.Tensor,  # (E, N, K//16) float8_e4m3fn
    weight_scale_2: torch.Tensor,  # (E, num_logical) or (E, 1) or scalar float32
    logical_widths: list[int] = None,
) -> torch.Tensor:
    """Fold the global scale into the block scales and return float32.

    Both inputs are converted to float32 and multiplied. How the global
    scale is broadcast depends on its shape:

    * one element: applied uniformly to every block scale;
    * last axis equal to ``len(logical_widths)``: entry ``i`` along that
      axis scales the ``i``-th run of ``logical_widths[i]`` rows (the
      second-to-last axis of ``weight_scale``). Leading axes, e.g. a
      per-expert axis, are kept, so each expert uses its own scales;
    * anything else: trailing singleton axes are appended until the rank
      matches ``weight_scale``, then the scale is expanded to its shape
      (e.g. ``(E,)`` or ``(E, 1)`` acts as a per-expert scalar).

    Args:
        weight_scale: block scales; rows are on the second-to-last axis.
        weight_scale_2: global scale(s) to fold in.
        logical_widths: row counts of the logical weights fused along the
            row axis, or ``None`` when there is no per-logical scale.

    Returns:
        float32 tensor of folded block scales.

    Raises:
        ValueError: if the per-logical path is selected but
            ``logical_widths`` does not sum to the number of rows.
    """
    sf_f32 = weight_scale.to(torch.float32)
    gs = weight_scale_2.to(torch.float32)

    if gs.numel() == 1:
        # Collapse to 0-dim so the product always keeps sf_f32's shape.
        return sf_f32 * gs.reshape(())

    if logical_widths is not None and gs.shape[-1] == len(logical_widths):
        n_rows = sf_f32.shape[-2]
        if sum(logical_widths) != n_rows:
            raise ValueError(
                f"logical_widths {list(logical_widths)} sum to "
                f"{sum(logical_widths)}, expected {n_rows} rows"
            )
        # Expand each logical scale over its rows along the LAST axis only,
        # so leading (per-expert) axes are preserved instead of flattened.
        lead_shape = gs.shape[:-1]
        per_row = torch.cat(
            [
                gs[..., i:i + 1].expand(*lead_shape, width)
                for i, width in enumerate(logical_widths)
            ],
            dim=-1,
        )
        return sf_f32 * per_row.unsqueeze(-1)

    # Per-expert or per-row global scale: broadcast multiply.
    while gs.dim() < sf_f32.dim():
        gs = gs.unsqueeze(-1)
    return sf_f32 * gs.expand_as(sf_f32)
def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
    """Pack every 4 consecutive one-byte scale values into one 32-bit word.

    The input is reinterpreted as raw bytes. Within each group of four
    along the last axis, byte ``k`` lands in bits ``[8k, 8k + 8)`` of the
    output word, i.e. the first value is the least-significant byte.

    Args:
        sf: tensor with a one-byte element type (e.g. float8_e4m3fn) whose
            last dimension is a multiple of 4.

    Returns:
        Contiguous tensor with the last dimension divided by 4. The dtype
        is ``torch.int32``; treat the values as raw 32-bit patterns.

    Raises:
        TypeError: if ``sf`` does not have a one-byte element type (a wider
            dtype would be silently reinterpreted and pack garbage).
        AssertionError: if the last dimension is not divisible by 4.
    """
    if sf.element_size() != 1:
        raise TypeError(
            f"expected a 1-byte dtype such as float8_e4m3fn, got {sf.dtype}"
        )
    sf_u8 = sf.view(torch.uint8)
    assert sf_u8.shape[-1] % 4 == 0, f"Last dim {sf_u8.shape[-1]} not divisible by 4"

    # Start from byte 0 and OR in bytes 1..3 at their bit offsets.
    packed = sf_u8[..., 0::4].to(torch.int32)
    for byte_idx in (1, 2, 3):
        packed = packed | (sf_u8[..., byte_idx::4].to(torch.int32) << (8 * byte_idx))
    return packed.contiguous()
def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor:
    """Interleave L1 (gate_up) weight rows for 2CTA UMMA.

    The row axis is processed in groups of 16. Inside each group, rows
    0-7 and rows 8-15 are interleaved one-by-one, so the output row order
    within a group is ``0, 8, 1, 9, ..., 7, 15``.

    Args:
        weight: 3-D tensor ``(E, N, K_half)`` with ``N`` divisible by 16.

    Returns:
        Contiguous tensor of the same shape with rows permuted as above.

    Raises:
        AssertionError: if ``N`` is not divisible by 16.
    """
    # NOTE(review): the permutation pairs the two halves of every 16-row
    # group; whether those halves correspond to gate/up rows depends on how
    # the caller lays out gate_up_proj — confirm against the kernel.
    E, N, K_half = weight.shape
    assert N % 16 == 0, f"N={N} not divisible by 16"

    groups = weight.view(E, N // 16, 16, K_half)
    lower_rows = groups[:, :, :8, :]
    upper_rows = groups[:, :, 8:16, :]
    # Stacking on a new axis right after the 8-row axis alternates
    # lower/upper rows once the result is flattened back to N rows.
    paired = torch.stack([lower_rows, upper_rows], dim=3)
    return paired.reshape(E, N, K_half).contiguous()
def _fold_clamp_pack(
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
    logical_widths: list[int] = None,
) -> torch.Tensor:
    """Fold an optional global scale into block scales, clamp, and pack."""
    if weight_scale_2 is not None:
        folded = _fold_global_scale(
            weight_scale, weight_scale_2, logical_widths=logical_widths
        )
    else:
        folded = weight_scale.to(torch.float32)
    # Clamp to [0, 448] before converting back to float8_e4m3fn.
    clamped = folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
    return _pack_ue4m3_to_uint32(clamped)


def transform_nvfp4_weights_for_tilelang(
    l1_tuple: tuple[torch.Tensor, torch.Tensor],  # (weight, weight_scale)
    l2_tuple: tuple[torch.Tensor, torch.Tensor],  # (weight, weight_scale)
    l1_weight_scale_2: torch.Tensor = None,  # float32 global scale for L1
    l2_weight_scale_2: torch.Tensor = None,  # float32 global scale for L2
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
    """Transform NVFP4 weights for the TileLang block-scaled GEMM.

    Matches the call signature from nightly vLLM deepseek_v4.py
    finalize_weights.

    For each of L1 and L2 the block scales are multiplied by the global
    scale (when given), clamped to ``[0, 448]``, converted to
    float8_e4m3fn and packed four-per-word. The L1 weight rows are then
    interleaved; the L2 weight is only made contiguous.

    Args:
        l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj
        l2_tuple: (w2_weight, w2_weight_scale) — down proj
        l1_weight_scale_2: global scale for L1 (float32), or ``None``
        l2_weight_scale_2: global scale for L2 (float32), or ``None``

    Returns:
        ((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed))
    """
    l1_weight, l1_weight_scale = l1_tuple
    l2_weight, l2_weight_scale = l2_tuple

    # The L1 scale carries two logical weights stacked along the row axis;
    # split the actual row count in half rather than assuming 3072 each.
    l1_rows = l1_weight_scale.shape[-2]
    first_width = l1_rows // 2
    l1_sf_packed = _fold_clamp_pack(
        l1_weight_scale,
        l1_weight_scale_2,
        logical_widths=[first_width, l1_rows - first_width],
    )
    l2_sf_packed = _fold_clamp_pack(l2_weight_scale, l2_weight_scale_2)

    # NOTE(review): the L1 weight rows are permuted below but the L1 scale
    # rows are packed in their original order — confirm the kernel expects
    # un-permuted scales.
    l1_weight_out = _interleave_l1_weights(l1_weight)
    l2_weight_out = l2_weight.contiguous()

    return (l1_weight_out, l1_sf_packed), (l2_weight_out, l2_sf_packed)
# Alias for drop-in replacement of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe
transform_nvfp4_weights_for_mega_moe = transform_nvfp4_weights_for_tilelang