test: simplify SF fill to avoid shape mismatch
This commit is contained in:
@@ -82,16 +82,10 @@ def test_nvfp4_mega_moe():
|
||||
symm_buffer.x[:num_tokens].copy_(
|
||||
torch.randint(0, 256, (num_tokens, hidden // 2), dtype=torch.uint8, device=device).view(torch.int8))
|
||||
# Write valid UE4M3 scales (random but non-zero)
|
||||
# Can't randn with float8, so generate in bf16 then cast
|
||||
sf_bf16 = torch.randn(num_tokens, hidden // 64, dtype=torch.bfloat16, device=device).abs().clamp(0.1, 5.0)
|
||||
# Pack as int32 (4 UE4M3 bytes per int32)
|
||||
sf_u8 = sf_bf16.to(torch.float8_e4m3fn).view(torch.uint8)
|
||||
# Pack 4 uint8 into int32
|
||||
sf_packed = (sf_u8[..., 0::4].to(torch.int32) |
|
||||
(sf_u8[..., 1::4].to(torch.int32) << 8) |
|
||||
(sf_u8[..., 2::4].to(torch.int32) << 16) |
|
||||
(sf_u8[..., 3::4].to(torch.int32) << 24))
|
||||
symm_buffer.x_sf[:num_tokens].copy_(sf_packed)
|
||||
# x_sf shape is (tokens, hidden//64) as int32 — each int32 = 4 packed UE4M3 bytes
|
||||
# Just fill with simple non-zero int32 values (the data doesn't need to be
|
||||
# perfectly valid UE4M3 for a launch test, just non-garbage)
|
||||
symm_buffer.x_sf[:num_tokens].fill_(0x3C3C3C3C) # repeating 0x3C = ~0.5 in E4M3
|
||||
# Write topk data directly
|
||||
for i in range(num_tokens):
|
||||
for j in range(top_k):
|
||||
|
||||
Reference in New Issue
Block a user