"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data."""

import os

import torch
import torch.distributed as dist


def _make_random_nvfp4_weight(num_experts: int, out_features: int,
                              in_features: int, device: str):
    """Build one synthetic NVFP4 expert-weight set.

    Args:
        num_experts: Size of the leading (expert) dimension.
        out_features: Number of rows of each expert matrix.
        in_features: Logical number of columns of each expert matrix.
        device: Torch device the tensors are created on.

    Returns:
        A ``(packed, block_scale, global_scale)`` tuple:

        * ``packed`` -- int8 tensor of shape
          ``(num_experts, out_features, in_features // 2)`` holding random
          bytes (E2M1 values packed two per byte).
        * ``block_scale`` -- float8_e4m3fn tensor of shape
          ``(num_experts, out_features, in_features // 16)`` with values
          drawn in ``[0.1, 10.0]`` before the float8 cast.
        * ``global_scale`` -- float32 tensor of shape ``(num_experts,)``
          with values in ``[0.5, 2.0]``.
    """
    # Random bytes are generated as uint8 and reinterpreted as int8 so the
    # full 0..255 bit-pattern range is covered.
    packed = torch.randint(
        0, 256,
        (num_experts, out_features, in_features // 2),
        dtype=torch.uint8, device=device,
    ).view(torch.int8)
    block_scale = (
        torch.randn(num_experts, out_features, in_features // 16, device=device)
        .abs()
        .clamp(0.1, 10.0)
        .to(torch.float8_e4m3fn)
    )
    global_scale = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
    return packed, block_scale, global_scale


def test_nvfp4_mega_moe():
    """Smoke-test ``fp8_nvfp4_mega_moe`` on a single GPU with random weights.

    Sets up a one-rank NCCL process group (unless one already exists),
    builds random NVFP4 weights, transforms them for the kernel, allocates
    the symmetric buffer and launches the kernel once, printing basic
    statistics of the output.

    Raises:
        Exception: Whatever the kernel call (or reading back its output)
            raises is printed and re-raised unchanged.
    """
    # Small dimensions that satisfy alignment requirements:
    # hidden and intermediate_hidden must be multiples of 128, and
    # hidden must be divisible by 64 (for NVFP4 SF packing).
    num_experts = 2
    num_tokens = 4
    top_k = 2
    hidden = 256  # must be multiple of 128 and 64
    intermediate_hidden = 512  # must be multiple of 128 and 64

    device = "cuda"
    torch.cuda.set_device(0)

    # Create a single-rank process group for SymmBuffer. setdefault() keeps
    # any values already supplied by a launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    if not dist.is_initialized():
        dist.init_process_group("nccl")
    group = dist.new_group()

    # Imported here rather than at module level so that importing this file
    # does not require deep_gemm to be installed.
    from deep_gemm.mega import (
        fp8_nvfp4_mega_moe,
        get_symm_buffer_for_nvfp4_mega_moe,
        transform_nvfp4_weights_for_mega_moe,
    )

    # Create random NVFP4 weights (E2M1 packed int8 + float8_e4m3fn block
    # scales).
    # w13: (num_experts, 2*intermediate_hidden, hidden//2)
    w13_weight, w13_weight_scale, w13_weight_scale_2 = _make_random_nvfp4_weight(
        num_experts, 2 * intermediate_hidden, hidden, device)
    # w2: (num_experts, hidden, intermediate_hidden//2)
    w2_weight, w2_weight_scale, w2_weight_scale_2 = _make_random_nvfp4_weight(
        num_experts, hidden, intermediate_hidden, device)

    # Transform weights for the kernel
    l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
        (w13_weight, w13_weight_scale),
        (w2_weight, w2_weight_scale),
        l1_weight_scale_2=w13_weight_scale_2,
        l2_weight_scale_2=w2_weight_scale_2,
    )
    print(f"l1_weights: dtype={l1_weights[0].dtype} shape={l1_weights[0].shape} strides={l1_weights[0].stride()}")
    print(f"l1_sf: dtype={l1_weights[1].dtype} shape={l1_weights[1].shape} strides={l1_weights[1].stride()}")
    print(f"l2_weights: dtype={l2_weights[0].dtype} shape={l2_weights[0].shape} strides={l2_weights[0].stride()}")
    print(f"l2_sf: dtype={l2_weights[1].dtype} shape={l2_weights[1].shape} strides={l2_weights[1].stride()}")

    # Create symm buffer
    symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
        group, num_experts, num_tokens, top_k, hidden, intermediate_hidden)

    # Create input (BF16) and top-k routing weights/ids.
    # TODO(review): these tensors are not yet copied into `symm_buffer`. The
    # staging helper that did so lives in a vLLM patch that cannot be
    # imported here, so the kernel currently runs on whatever the buffer
    # contains after allocation. Stage them manually once the buffer layout
    # is confirmed.
    hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1)
    topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)

    # Output tensor
    y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)

    # Call the kernel
    print("Calling fp8_nvfp4_mega_moe...")
    try:
        fp8_nvfp4_mega_moe(
            y,
            l1_weights,
            l2_weights,
            symm_buffer,
        )
        # .item() forces a device sync, so asynchronous kernel failures
        # surface inside this try block.
        print("SUCCESS! y stats: min={:.4f} max={:.4f} mean={:.4f}".format(
            y.min().item(), y.max().item(), y.mean().item()))
    except Exception as e:
        print(f"FAILED: {e}")
        raise


if __name__ == "__main__":
    try:
        test_nvfp4_mega_moe()
    finally:
        # Tear down the process group created by the test so NCCL resources
        # are released before interpreter shutdown.
        if dist.is_initialized():
            dist.destroy_process_group()