test: simplify SF fill to avoid shape mismatch

This commit is contained in:
2026-05-12 15:13:16 +00:00
parent d4c557fddc
commit fcd6de0a60

View File

@@ -82,16 +82,10 @@ def test_nvfp4_mega_moe():
symm_buffer.x[:num_tokens].copy_(
torch.randint(0, 256, (num_tokens, hidden // 2), dtype=torch.uint8, device=device).view(torch.int8))
# Write valid UE4M3 scales (random but non-zero)
# Can't randn with float8, so generate in bf16 then cast
sf_bf16 = torch.randn(num_tokens, hidden // 64, dtype=torch.bfloat16, device=device).abs().clamp(0.1, 5.0)
# Pack as int32 (4 UE4M3 bytes per int32)
sf_u8 = sf_bf16.to(torch.float8_e4m3fn).view(torch.uint8)
# Pack 4 uint8 into int32
sf_packed = (sf_u8[..., 0::4].to(torch.int32) |
(sf_u8[..., 1::4].to(torch.int32) << 8) |
(sf_u8[..., 2::4].to(torch.int32) << 16) |
(sf_u8[..., 3::4].to(torch.int32) << 24))
symm_buffer.x_sf[:num_tokens].copy_(sf_packed)
# x_sf has shape (tokens, hidden // 64) with dtype int32; each int32 holds 4 packed UE4M3 bytes.
# Fill it with a simple non-zero int32 pattern: for a launch test the scales do not
# need to be perfectly valid UE4M3, they only need to be well-defined (not garbage).
symm_buffer.x_sf[:num_tokens].fill_(0x3C3C3C3C) # each byte 0x3C = 0b0_0111_100 = 1.5 in E4M3 (exponent bias 7)
# Write topk data directly
for i in range(num_tokens):
for j in range(top_k):