From 1a452ffabd1205085553b090e8936f65a7d1ca6e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 13 May 2026 21:41:43 +0000 Subject: [PATCH] Fix weight_transform signature to match nightly vLLM finalize_weights call --- src/nvfp4_megamoe_kernel/weight_transform.py | 147 ++++++++++--------- 1 file changed, 76 insertions(+), 71 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/weight_transform.py b/src/nvfp4_megamoe_kernel/weight_transform.py index 8ab2cad9..90fc3018 100644 --- a/src/nvfp4_megamoe_kernel/weight_transform.py +++ b/src/nvfp4_megamoe_kernel/weight_transform.py @@ -7,112 +7,117 @@ into the TMA-aligned format expected by the block-scaled GEMM kernel: - Packed UE4M3 scales (uint32, TMA-aligned UTCCP layout) This replaces deep_gemm.mega.transform_nvfp4_weights_for_mega_moe. + +Call signature matches the nightly vLLM deepseek_v4.py finalize_weights: + transform_nvfp4_weights_for_mega_moe( + (l1_weight, l1_weight_scale), + (l2_weight, l2_weight_scale), + l1_weight_scale_2=..., + l2_weight_scale_2=..., + ) """ import torch -import math -def fold_global_scale( - weight_scale: torch.Tensor, # (N, K//16) float8_e4m3fn - weight_scale_2: torch.Tensor, # (num_logical,) or scalar float32 - logical_widths: list[int] = None, # per-logical-weight row counts +def _fold_global_scale( + weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn + weight_scale_2: torch.Tensor, # (E, num_logical) or (E, 1) or scalar float32 + logical_widths: list[int] = None, ) -> torch.Tensor: - """Fold global scale into block scales: UE4M3 * FP32 → UE4M3 → float32. - - Returns: (N, K//16) float32 folded block scales. - """ + """Fold global scale into block scales: UE4M3 * FP32 → float32.""" sf_f32 = weight_scale.to(torch.float32) - - if weight_scale_2.numel() == 1: - sf_f32 = sf_f32 * weight_scale_2.to(torch.float32) - elif weight_scale_2.numel() > 1 and logical_widths is not None: - # Per-logical-weight global scale (e.g., gate_up_proj has 2) + gs = weight_scale_2.to(torch.float32) + + if gs.numel() == 1: + sf_f32 = sf_f32 * gs + elif gs.numel() > 1 and logical_widths is not None: expanded = [] for i, w in enumerate(logical_widths): - if i < len(weight_scale_2): - expanded.append(weight_scale_2[i].flatten()[0].expand(w)) - global_scale = torch.cat(expanded).to(torch.float32).unsqueeze(1) - sf_f32 = sf_f32 * global_scale + if i < len(gs.flatten()): + expanded.append(gs.flatten()[i].expand(w)) + if expanded: + global_scale = torch.cat(expanded).unsqueeze(1) + sf_f32 = sf_f32 * global_scale else: - sf_f32 = sf_f32 * weight_scale_2.max().to(torch.float32) - + # Per-expert or per-row global scale — broadcast multiply + while gs.dim() < sf_f32.dim(): + gs = gs.unsqueeze(-1) + sf_f32 = sf_f32 * gs.expand_as(sf_f32) + return sf_f32 -def pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor: - """Pack 4 UE4M3 (float8_e4m3fn) values into one uint32. - - Input: (..., K//16) float8_e4m3fn - Output: (..., K//64) uint32 (4 values packed per word) - """ - # View as raw bytes - sf_u8 = sf.view(torch.uint8) # (..., K//16) uint8 - shape = sf_u8.shape - assert shape[-1] % 4 == 0, f"Last dim {shape[-1]} not divisible by 4" - - # Pack 4 consecutive uint8 values into one uint32 +def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor: + """Pack 4 UE4M3 (float8_e4m3fn) values into one uint32.""" + sf_u8 = sf.view(torch.uint8) + assert sf_u8.shape[-1] % 4 == 0, f"Last dim {sf_u8.shape[-1]} not divisible by 4" packed = (sf_u8[..., 0::4].to(torch.int32) | (sf_u8[..., 1::4].to(torch.int32) << 8) | (sf_u8[..., 2::4].to(torch.int32) << 16) | (sf_u8[..., 3::4].to(torch.int32) << 24)) - return packed.contiguous() -def interleave_l1_weights( - weight: torch.Tensor, # (E, 2*INTER, K//2) int8, K-major -) -> torch.Tensor: - """Interleave L1 (gate_up) weights for 2CTA UMMA. - - The gate and up projections are interleaved in groups of 8 rows - to match the UMMA 2CTA schedule. - """ +def _interleave_l1_weights(weight: torch.Tensor) -> torch.Tensor: + """Interleave L1 (gate_up) weights for 2CTA UMMA.""" E, N, K_half = weight.shape assert N % 16 == 0, f"N={N} not divisible by 16" - - # Reshape to (E, N//16, 16, K_half) → interleave pairs of 8 w = weight.view(E, N // 16, 16, K_half) w_gate = w[:, :, :8, :] w_up = w[:, :, 8:16, :] interleaved = torch.stack([w_gate, w_up], dim=3).reshape(E, N, K_half) - return interleaved.contiguous() def transform_nvfp4_weights_for_tilelang( - weight: torch.Tensor, # (E, N, K//2) int8, K-major packed FP4 - weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn UE4M3 - weight_scale_2: torch.Tensor, # (E, 1) or (E, 2) float32 global scale - N: int, # output dimension (2*INTER for L1, HIDDEN for L2) - K: int, # input dimension (HIDDEN for L1, INTER for L2) - gran_k: int = 16, # NVFP4 group_size - is_l1: bool = False, # True for gate_up_proj (needs interleaving) - logical_widths: list[int] = None, -) -> tuple[torch.Tensor, torch.Tensor]: + l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) + l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale) + l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1 + l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2 +) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: """Transform NVFP4 weights for the TileLang block-scaled GEMM. + Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights. + + Args: + l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj + l2_tuple: (w2_weight, w2_weight_scale) — down proj + l1_weight_scale_2: global scale for L1 (float32) + l2_weight_scale_2: global scale for L2 (float32) + Returns: - transformed_weight: (E, N, K//2) int8, K-major, contiguously interleaved - transformed_sf: (E, N, K//64) uint32, packed UE4M3 with folded global scale + ((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed)) """ - # Step 1: Fold global scale into block scales - sf_folded = fold_global_scale(weight_scale, weight_scale_2, logical_widths) - - # Step 2: Clamp and convert back to UE4M3 - sf_clamped = sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - - # Step 3: Pack UE4M3 into uint32 - sf_packed = pack_ue4m3_to_uint32(sf_clamped) # (E, N, K//64) - - # Step 4: Interleave L1 weights if needed - if is_l1: - transformed_weight = interleave_l1_weights(weight) + l1_weight, l1_weight_scale = l1_tuple + l2_weight, l2_weight_scale = l2_tuple + + # Fold global scales into block scales + if l1_weight_scale_2 is not None: + l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2, + logical_widths=[3072, 3072]) else: - transformed_weight = weight.contiguous() - - return transformed_weight, sf_packed + l1_sf_folded = l1_weight_scale.to(torch.float32) + + if l2_weight_scale_2 is not None: + l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) + else: + l2_sf_folded = l2_weight_scale.to(torch.float32) + + # Clamp and convert back to UE4M3 + l1_sf_clamped = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + l2_sf_clamped = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn) + + # Pack UE4M3 into uint32 + l1_sf_packed = _pack_ue4m3_to_uint32(l1_sf_clamped) + l2_sf_packed = _pack_ue4m3_to_uint32(l2_sf_clamped) + + # Interleave L1 weights (gate_up for 2CTA UMMA) + l1_weight_out = _interleave_l1_weights(l1_weight) + l2_weight_out = l2_weight.contiguous() + + return (l1_weight_out, l1_sf_packed), (l2_weight_out, l2_sf_packed) -# Alias for drop-in replacement of deep_gemm.mega.transform_nvfp4_weights_for_mega_moe +# Alias for drop-in replacement transform_nvfp4_weights_for_mega_moe = transform_nvfp4_weights_for_tilelang