tweax n shit

This commit is contained in:
2026-05-12 23:16:44 +00:00
parent 8b27e85ee5
commit ad335c38fb
2 changed files with 4 additions and 33 deletions

View File

@@ -187,35 +187,6 @@ def transform_weights_for_mega_moe(
return l1_weights, l2_weights
def _pack_nvfp4_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
"""Pack NVFP4 UE4M3 block scales (float8_e4m3fn) into int32 UTCCP layout.
NVFP4 uses UE4M3 scales with group_size=16 (scale_vec::4X).
The UTCCP layout packs 4 consecutive scale bytes into each int32,
then applies the 4x32 transpose for TMA consumption.
Input: (num_experts, mn, K//16) float8_e4m3fn scales
Output: (num_experts, mn, K//64) int32 packed UTCCP-transposed scales
"""
num_groups, mn, sf_k = sf.shape
assert sf_k % 4 == 0, f"NVFP4 SF K dim must be divisible by 4, got {sf_k}"
assert mn % 128 == 0, f"MN dim must be divisible by 128, got {mn}"
# View as uint8 and pack 4 consecutive bytes into int32
sf_uint8 = sf.view(torch.uint8) # (num_groups, mn, sf_k)
# Pack: every 4 uint8 → 1 int32
packed = (sf_uint8[..., 0::4].to(torch.int32) |
(sf_uint8[..., 1::4].to(torch.int32) << 8) |
(sf_uint8[..., 2::4].to(torch.int32) << 16) |
(sf_uint8[..., 3::4].to(torch.int32) << 24)) # (num_groups, mn, sf_k//4)
# Apply UTCCP 4x32 transpose (same as MXFP4 — the transpose is determined
# by the 128-element alignment, not the scale vector size)
packed_sf_k = sf_k // 4
result = (packed.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(packed).copy_(result)
def transform_nvfp4_weights_for_mega_moe(