diff --git a/patches/test_nvfp4_mega_moe.py b/patches/test_nvfp4_mega_moe.py index 0c733bd..c6ff286 100644 --- a/patches/test_nvfp4_mega_moe.py +++ b/patches/test_nvfp4_mega_moe.py @@ -9,12 +9,14 @@ import os import sys def test_nvfp4_mega_moe(): - # Small but aligned dimensions + # Use dimensions that satisfy all alignment requirements: + # - hidden and intermediate_hidden must be multiples of 128 and 64 + # - block_m will be at least 32 (SMEM alignment: 32 * 64 = 2048 >= 1024) num_experts = 2 - num_tokens = 8 # must be multiple of alignment (8 for block_m=8) + num_tokens = 32 # must be multiple of alignment top_k = 2 - hidden = 256 # must be multiple of 128 and 64 - intermediate_hidden = 512 # must be multiple of 128 and 64 + hidden = 512 # multiple of 128 and 64 + intermediate_hidden = 1024 # multiple of 128 and 64 device = "cuda" torch.cuda.set_device(0)