fix: use transform_sf_into_required_layout for proper TMA-aligned SF
Instead of the custom `_pack_nvfp4_sf_for_utccp` helper, use DeepGEMM's C++ `transform_sf_into_required_layout` with recipe (1, 1, 16) for NVFP4. This handles TMA alignment and the UTCCP transpose correctly.
This commit is contained in:
@@ -144,41 +144,42 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Transform NVFP4 expert weights for the mega_moe kernel.
|
||||
|
||||
NVFP4 weights come as (weight, scale) where:
|
||||
- weight: uint8 E2M1 packed, shape (num_experts, N, K//2)
|
||||
- scale: float8_e4m3fn UE4M3 block scales, shape (num_experts, N, K//16)
|
||||
|
||||
The kernel expects (weight, packed_sf) where packed_sf is int32 UTCCP layout.
|
||||
|
||||
If weight_scale_2 (float32 global scale) is provided, it is folded into the
|
||||
block scales: effective_scale = block_scale * global_scale → re-quantized to UE4M3.
|
||||
This is needed because the kernel only applies one level of scaling.
|
||||
Uses deep_gemm.transform_sf_into_required_layout for proper TMA-aligned
|
||||
UTCCP layout with recipe (1, 1, 16) for NVFP4 group_size=16.
|
||||
"""
|
||||
from deep_gemm import transform_sf_into_required_layout
|
||||
|
||||
def fold_global_scale(sf: torch.Tensor, scale_2: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
"""Fold weight_scale_2 into block scales: UE4M3 * FP32 → UE4M3"""
|
||||
if scale_2 is None:
|
||||
return sf
|
||||
# Dequantize UE4M3 → float32
|
||||
sf_f32 = sf.to(torch.float32)
|
||||
# Multiply by global scale
|
||||
# scale_2 shape: (num_experts, 1, 1) or (num_experts,) — broadcast over mn and K
|
||||
if scale_2.dim() == 1:
|
||||
scale_2 = scale_2.view(-1, 1, 1)
|
||||
sf_f32 = sf_f32 * scale_2
|
||||
# Re-quantize to UE4M3 (clamp to [0, 448])
|
||||
sf_f32 = sf_f32.clamp(0.0, 448.0)
|
||||
return sf_f32.to(torch.float8_e4m3fn)
|
||||
|
||||
# Fold global scales into block scales
|
||||
l1_sf = fold_global_scale(l1_weights[1], l1_weight_scale_2)
|
||||
l2_sf = fold_global_scale(l2_weights[1], l2_weight_scale_2)
|
||||
|
||||
# L1: interleave gate/up, then pack + transpose SF for UTCCP
|
||||
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf))
|
||||
# DeepGEMM expects int8 (kPackedFP4 = torch.kInt8), but NVFP4 weights are uint8
|
||||
l1_out = (l1_interleaved[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
|
||||
# L2: only pack + transpose SF for UTCCP
|
||||
l2_out = (l2_weights[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l2_sf))
|
||||
num_experts = l1_weights[0].shape[0]
|
||||
l1_n = l1_weights[0].shape[1] # intermediate_size * 2
|
||||
l1_k = l1_weights[0].shape[2] * 2 # K (weight is K//2 uint8)
|
||||
l2_n = l2_weights[0].shape[1]
|
||||
l2_k = l2_weights[0].shape[2] * 2
|
||||
|
||||
# Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function
|
||||
# recipe (1, 1, 16): gran_mn=1, gran_k=16
|
||||
l1_sf_packed = transform_sf_into_required_layout(
|
||||
l1_sf, l1_n, l1_k, (1, 16), num_experts)
|
||||
l2_sf_packed = transform_sf_into_required_layout(
|
||||
l2_sf, l2_n, l2_k, (1, 16), num_experts)
|
||||
|
||||
# L1: interleave gate/up
|
||||
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed))
|
||||
# DeepGEMM expects int8 (kPackedFP4 = torch.kInt8)
|
||||
l1_out = (l1_interleaved[0].view(torch.int8), l1_interleaved[1])
|
||||
l2_out = (l2_weights[0].view(torch.int8), l2_sf_packed)
|
||||
return l1_out, l2_out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user