diff --git a/patches/test_nvfp4_mega_moe.py b/patches/test_nvfp4_mega_moe.py index b956aaa..778b957 100644 --- a/patches/test_nvfp4_mega_moe.py +++ b/patches/test_nvfp4_mega_moe.py @@ -82,16 +82,10 @@ def test_nvfp4_mega_moe(): symm_buffer.x[:num_tokens].copy_( torch.randint(0, 256, (num_tokens, hidden // 2), dtype=torch.uint8, device=device).view(torch.int8)) # Write valid UE4M3 scales (random but non-zero) - # Can't randn with float8, so generate in bf16 then cast - sf_bf16 = torch.randn(num_tokens, hidden // 64, dtype=torch.bfloat16, device=device).abs().clamp(0.1, 5.0) - # Pack as int32 (4 UE4M3 bytes per int32) - sf_u8 = sf_bf16.to(torch.float8_e4m3fn).view(torch.uint8) - # Pack 4 uint8 into int32 - sf_packed = (sf_u8[..., 0::4].to(torch.int32) | - (sf_u8[..., 1::4].to(torch.int32) << 8) | - (sf_u8[..., 2::4].to(torch.int32) << 16) | - (sf_u8[..., 3::4].to(torch.int32) << 24)) - symm_buffer.x_sf[:num_tokens].copy_(sf_packed) + # x_sf shape is (tokens, hidden//64) as int32 — each int32 = 4 packed UE4M3 bytes + # Just fill with simple non-zero int32 values (the data doesn't need to be + # perfectly valid UE4M3 for a launch test, just non-garbage) + symm_buffer.x_sf[:num_tokens].fill_(0x3C3C3C3C) # repeating 0x3C = 1.5 in E4M3 # Write topk data directly for i in range(num_tokens): for j in range(top_k):