Fix weight_transform signature to match nightly vLLM finalize_weights call
This commit is contained in:
@@ -7,112 +7,117 @@ into the TMA-aligned format expected by the block-scaled GEMM kernel:
|
||||
- Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout)
|
||||
|
||||
This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe.
|
||||
|
||||
Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:
|
||||
transform_nvfp4_weights_for_mega_moe(
|
||||
(l1_weight, l1_weight_scale),
|
||||
(l2_weight, l2_weight_scale),
|
||||
l1_weight_scale_2=...,
|
||||
l2_weight_scale_2=...,
|
||||
)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
|
||||
def fold_global_scale(
|
||||
weight_scale: torch.Tensor, # (N, K//16) float8_e4m3fn
|
||||
weight_scale_2: torch.Tensor, # (num_logical,) or scalar float32
|
||||
logical_widths: list[int] = None, # per-logical-weight row counts
|
||||
def _fold_global_scale(
|
||||
weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn
|
||||
weight_scale_2: torch.Tensor, # (E, num_logical) or (E, 1) or scalar float32
|
||||
logical_widths: list[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Fold global scale into block scales: UE4M3 * FP32 → UE4M3 → float32.
|
||||
|
||||
Returns: (N, K//16) float32 folded block scales.
|
||||
"""
|
||||
"""Fold global scale into block scales: UE4M3 * FP32 → float32."""
|
||||
sf_f32 = weight_scale.to(torch.float32)
|
||||
|
||||
if weight_scale_2.numel() == 1:
|
||||
sf_f32 = sf_f32 * weight_scale_2.to(torch.float32)
|
||||
elif weight_scale_2.numel() > 1 and logical_widths is not None:
|
||||
# Per-logical-weight global scale (e.g., gate_up_proj has 2)
|
||||
gs = weight_scale_2.to(torch.float32)
|
||||
|
||||
if gs.numel() == 1:
|
||||
sf_f32 = sf_f32 * gs
|
||||
elif gs.numel() > 1 and logical_widths is not None:
|
||||
expanded = []
|
||||
for i, w in enumerate(logical_widths):
|
||||
if i < len(weight_scale_2):
|
||||
expanded.append(weight_scale_2[i].flatten()[0].expand(w))
|
||||
global_scale = torch.cat(expanded).to(torch.float32).unsqueeze(1)
|
||||
sf_f32 = sf_f32 * global_scale
|
||||
if i < len(gs.flatten()):
|
||||
expanded.append(gs.flatten()[i].expand(w))
|
||||
if expanded:
|
||||
global_scale = torch.cat(expanded).unsqueeze(1)
|
||||
sf_f32 = sf_f32 * global_scale
|
||||
else:
|
||||
sf_f32 = sf_f32 * weight_scale_2.max().to(torch.float32)
|
||||
|
||||
# Per-expert or per-row global scale — broadcast multiply
|
||||
while gs.dim() < sf_f32.dim():
|
||||
gs = gs.unsqueeze(-1)
|
||||
sf_f32 = sf_f32 * gs.expand_as(sf_f32)
|
||||
|
||||
return sf_f32
|
||||
|
||||
|
||||
def pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack 4 UE4M3 (float8_e4m3fn) values into one uint32.
|
||||
|
||||
Input: (..., K//16) float8_e4m3fn
|
||||
Output: (..., K//64) uint32 (4 values packed per word)
|
||||
"""
|
||||
# View as raw bytes
|
||||
sf_u8 = sf.view(torch.uint8) # (..., K//16) uint8
|
||||
shape = sf_u8.shape
|
||||
assert shape[-1] % 4 == 0, f"Last dim {shape[-1]} not divisible by 4"
|
||||
|
||||
# Pack 4 consecutive uint8 values into one uint32
|
||||
def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack 4 UE4M3 (float8_e4m3fn) values into one uint32."""
|
||||
sf_u8 = sf.view(torch.uint8)
|
||||
assert sf_u8.shape[-1] % 4 == 0, f"Last dim {sf_u8.shape[-1]} not divisible by 4"
|
||||
packed = (sf_u8[..., 0::4].to(torch.int32) |
|
||||
(sf_u8[..., 1::4].to(torch.int32) << 8) |
|
||||
(sf_u8[..., 2::4].to(torch.int32) << 16) |
|
||||
(sf_u8[..., 3::4].to(torch.int32) << 24))
|
||||
|
||||
return packed.contiguous()
|
||||
|
||||
|
||||
def interleave_l1_weights(
|
||||
weight: torch.Tensor, # (E, 2*INTER, K//2) int8, K-major
|
||||
) -> torch.Tensor:
|
||||
"""Interleave L1 (gate_up) weights for 2CTA UMMA.
|
||||
|
||||
The gate and up projections are interleaved in groups of 8 rows
|
||||
to match the UMMA 2CTA schedule.
|
||||
"""
|
||||
def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor:
|
||||
"""Interleave L1 (gate_up) weights for 2CTA UMMA."""
|
||||
E, N, K_half = weight.shape
|
||||
assert N % 16 == 0, f"N={N} not divisible by 16"
|
||||
|
||||
# Reshape to (E, N//16, 16, K_half) → interleave pairs of 8
|
||||
w = weight.view(E, N // 16, 16, K_half)
|
||||
w_gate = w[:, :, :8, :]
|
||||
w_up = w[:, :, 8:16, :]
|
||||
interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half)
|
||||
|
||||
return interleaved.contiguous()
|
||||
|
||||
|
||||
def transform_nvfp4_weights_for_tilelang(
|
||||
weight: torch.Tensor, # (E, N, K//2) int8, K-major packed FP4
|
||||
weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn UE4M3
|
||||
weight_scale_2: torch.Tensor, # (E, 1) or (E, 2) float32 global scale
|
||||
N: int, # output dimension (2*INTER for L1, HIDDEN for L2)
|
||||
K: int, # input dimension (HIDDEN for L1, INTER for L2)
|
||||
gran_k: int = 16, # NVFP4 group_size
|
||||
is_l1: bool = False, # True for gate_up_proj (needs interleaving)
|
||||
logical_widths: list[int] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1
|
||||
l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Transform NVFP4 weights for the TileLang block-scaled GEMM.
|
||||
|
||||
Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights.
|
||||
|
||||
Args:
|
||||
l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj
|
||||
l2_tuple: (w2_weight, w2_weight_scale) — down proj
|
||||
l1_weight_scale_2: global scale for L1 (float32)
|
||||
l2_weight_scale_2: global scale for L2 (float32)
|
||||
|
||||
Returns:
|
||||
transformed_weight: (E, N, K//2) int8, K-major, contiguously interleaved
|
||||
transformed_sf: (E, N, K//64) uint32, packed UE4M3 with folded global scale
|
||||
((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed))
|
||||
"""
|
||||
# Step 1: Fold global scale into block scales
|
||||
sf_folded = fold_global_scale(weight_scale, weight_scale_2, logical_widths)
|
||||
|
||||
# Step 2: Clamp and convert back to UE4M3
|
||||
sf_clamped = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
|
||||
# Step 3: Pack UE4M3 into uint32
|
||||
sf_packed = pack_ue4m3_to_uint32(sf_clamped) # (E, N, K//64)
|
||||
|
||||
# Step 4: Interleave L1 weights if needed
|
||||
if is_l1:
|
||||
transformed_weight = interleave_l1_weights(weight)
|
||||
l1_weight, l1_weight_scale = l1_tuple
|
||||
l2_weight, l2_weight_scale = l2_tuple
|
||||
|
||||
# Fold global scales into block scales
|
||||
if l1_weight_scale_2 is not None:
|
||||
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2,
|
||||
logical_widths=[3072, 3072])
|
||||
else:
|
||||
transformed_weight = weight.contiguous()
|
||||
|
||||
return transformed_weight, sf_packed
|
||||
l1_sf_folded = l1_weight_scale.to(torch.float32)
|
||||
|
||||
if l2_weight_scale_2 is not None:
|
||||
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2)
|
||||
else:
|
||||
l2_sf_folded = l2_weight_scale.to(torch.float32)
|
||||
|
||||
# Clamp and convert back to UE4M3
|
||||
l1_sf_clamped = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
l2_sf_clamped = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
|
||||
# Pack UE4M3 into uint32
|
||||
l1_sf_packed = _pack_ue4m3_to_uint32(l1_sf_clamped)
|
||||
l2_sf_packed = _pack_ue4m3_to_uint32(l2_sf_clamped)
|
||||
|
||||
# Interleave L1 weights (gate_up for 2CTA UMMA)
|
||||
l1_weight_out = _interleave_l1_weights(l1_weight)
|
||||
l2_weight_out = l2_weight.contiguous()
|
||||
|
||||
return (l1_weight_out, l1_sf_packed), (l2_weight_out, l2_sf_packed)
|
||||
|
||||
|
||||
# Alias for drop-in replacement of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe
|
||||
# Alias for drop-in replacement
|
||||
transform_nvfp4_weights_for_mega_moe = transform_nvfp4_weights_for_tilelang
|
||||
|
||||
Reference in New Issue
Block a user