diff --git a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp index 17948cd..12a78d3 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_nvfp4_mega_moe.hpp @@ -157,10 +157,10 @@ static void sm100_fp8_nvfp4_mega_moe( intermediate_hidden * 2, hidden, config.block_n, kGranK, num_experts_per_rank, 0); - // L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1) + // L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/2 bytes (packed), no swizzle (v1) const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, intermediate_hidden / 2, config.num_max_pool_tokens, - config.block_n / 4, config.store_block_m, + config.block_n / 2, config.store_block_m, static_cast(l2_acts.stride(-2)), 0, 0, // no swizzle false, // allow_tf32 diff --git a/test_nvfp4_mega_moe.py b/test_nvfp4_mega_moe.py new file mode 100644 index 0000000..c2aab23 --- /dev/null +++ b/test_nvfp4_mega_moe.py @@ -0,0 +1,97 @@ +"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data.""" +import torch +import torch.distributed as dist +import os + +def test_nvfp4_mega_moe(): + # Small dimensions that satisfy alignment requirements + # hidden and intermediate_hidden must be multiples of 128 + # hidden must be divisible by 64 (for NVFP4 SF packing) + num_experts = 2 + num_tokens = 4 + top_k = 2 + hidden = 256 # must be multiple of 128 and 64 + intermediate_hidden = 512 # must be multiple of 128 and 64 + + device = "cuda" + torch.cuda.set_device(0) + + # Create a single-rank process group for SymmBuffer + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + if not dist.is_initialized(): + dist.init_process_group("nccl") + group = dist.new_group() + + from deep_gemm.mega import ( + fp8_nvfp4_mega_moe, + get_symm_buffer_for_nvfp4_mega_moe, + transform_nvfp4_weights_for_mega_moe, + ) + + # Create random NVFP4 weights (E2M1 packed int8 + float8_e4m3fn block scales) + # w13: (num_experts, 2*intermediate_hidden, hidden//2) + w13_weight = torch.randint(0, 256, (num_experts, 2 * intermediate_hidden, hidden // 2), + dtype=torch.uint8, device=device).view(torch.int8) + w13_weight_scale = torch.randn(num_experts, 2 * intermediate_hidden, hidden // 16, + device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn) + w13_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0) + w13_input_scale = torch.ones(num_experts, device=device) + + # w2: (num_experts, hidden, intermediate_hidden//2) + w2_weight = torch.randint(0, 256, (num_experts, hidden, intermediate_hidden // 2), + dtype=torch.uint8, device=device).view(torch.int8) + w2_weight_scale = torch.randn(num_experts, hidden, intermediate_hidden // 16, + device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn) + w2_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0) + w2_input_scale = torch.ones(num_experts, device=device) + + # Transform weights for the kernel + l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe( + (w13_weight, w13_weight_scale), + (w2_weight, w2_weight_scale), + l1_weight_scale_2=w13_weight_scale_2, + l2_weight_scale_2=w2_weight_scale_2, + ) + + print(f"l1_weights: dtype={l1_weights[0].dtype} shape={l1_weights[0].shape} strides={l1_weights[0].stride()}") + print(f"l1_sf: dtype={l1_weights[1].dtype} shape={l1_weights[1].shape} strides={l1_weights[1].stride()}") + print(f"l2_weights: dtype={l2_weights[0].dtype} shape={l2_weights[0].shape} strides={l2_weights[0].stride()}") + print(f"l2_sf: dtype={l2_weights[1].dtype} shape={l2_weights[1].shape} strides={l2_weights[1].stride()}") + + # Create symm buffer + symm_buffer = get_symm_buffer_for_nvfp4_mega_moe( + group, num_experts, num_tokens, top_k, hidden, intermediate_hidden) + + # Create input (BF16) + hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device) + + # Create topk weights/ids + topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1) + topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device) + + # Stage inputs + from deepseek_v4_staging import _stage_deepseek_v4_mega_moe_inputs + # Actually, we can't import from vllm patch. Let's just manually set up the symm buffer. + + # Output tensor + y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device) + + # Call the kernel + print("Calling fp8_nvfp4_mega_moe...") + try: + fp8_nvfp4_mega_moe( + y, + l1_weights, l2_weights, + symm_buffer, + ) + print("SUCCESS! y stats: min={:.4f} max={:.4f} mean={:.4f}".format( + y.min().item(), y.max().item(), y.mean().item())) + except Exception as e: + print(f"FAILED: {e}") + raise + +if __name__ == "__main__": + test_nvfp4_mega_moe()