Files
DeepGEMM/test_nvfp4_mega_moe.py
biondizzle 74bf612771 NVFP4 mega MoE: sf_id=0 fix for scale_vec::4X + UINT8 TMA + SF pipeline + interleaving
Root cause of the ILLEGAL_INSTRUCTION fault: make_runtime_instr_desc_with_sf_id(instr_desc, k, k)
passed sf_id=k, i.e. sf_id=1 for k=1 (the second UMMA atom), but mxf4nvf4 with scale_vec::4X
requires sf_id=0 always — the hardware implicitly reads 4 SF positions per atom from a single
TMEM region. A non-zero sf_id makes the hardware access invalid TMEM offsets.

Also includes:
- UINT8 TMA for packed FP4 (avoids 16U4 driver bugs)
- NVFP4 SF pipeline: 2 K-columns per BLOCK_K for group_size=16
- MN-major SF interleaving for gate/up L1 weights
- Fix contiguous copy for SF byte view
- Preserve MN-major layout in SF interleave
- Force contiguous on SF tensors before C++ call
- Unpack weight tuples before printing
- Single transpose back to MN-major (don't double-transpose)
2026-05-12 20:26:13 +00:00

98 lines
4.1 KiB
Python

"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data."""
import torch
import torch.distributed as dist
import os
def _random_nvfp4_expert_weights(num_experts, out_features, in_features, device):
    """Create random NVFP4-style expert weights for a synthetic test.

    Returns ``(packed, block_scales, global_scales)``:

    * ``packed``: int8 tensor of shape ``(num_experts, out_features, in_features // 2)``
      holding random bytes (two packed E2M1 values per byte).
    * ``block_scales``: float8_e4m3fn tensor of shape
      ``(num_experts, out_features, in_features // 16)`` (one scale per 16 inputs),
      drawn as ``|randn|`` clamped to ``[0.1, 10.0]`` before the cast.
    * ``global_scales``: one value per expert, ``|randn|`` clamped to ``[0.5, 2.0]``.
    """
    # Draw full-range bytes as uint8 (0..255), then reinterpret them as int8.
    packed = torch.randint(
        0, 256, (num_experts, out_features, in_features // 2),
        dtype=torch.uint8, device=device,
    ).view(torch.int8)
    block_scales = (
        torch.randn(num_experts, out_features, in_features // 16, device=device)
        .abs()
        .clamp(0.1, 10.0)
        .to(torch.float8_e4m3fn)
    )
    global_scales = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
    return packed, block_scales, global_scales


def _run_nvfp4_mega_moe(group, device):
    """Transform synthetic NVFP4 weights and launch ``fp8_nvfp4_mega_moe`` once on *group*.

    Raises:
        Whatever the DeepGEMM transform or kernel raises; kernel failures are
        printed as ``FAILED: ...`` before being re-raised.
    """
    from deep_gemm.mega import (
        fp8_nvfp4_mega_moe,
        get_symm_buffer_for_nvfp4_mega_moe,
        transform_nvfp4_weights_for_mega_moe,
    )

    # Small dimensions that satisfy the alignment requirements: hidden and
    # intermediate_hidden must be multiples of 128 (hence also of 64, which
    # hidden needs for NVFP4 SF packing).
    num_experts = 2
    num_tokens = 4
    top_k = 2
    hidden = 256
    intermediate_hidden = 512
    assert hidden % 128 == 0 and intermediate_hidden % 128 == 0

    # Fixed seed so a failing launch can be reproduced with identical data.
    torch.manual_seed(0)

    # w13 (L1, gate/up): packed shape (num_experts, 2 * intermediate_hidden, hidden // 2).
    w13_weight, w13_weight_scale, w13_weight_scale_2 = _random_nvfp4_expert_weights(
        num_experts, 2 * intermediate_hidden, hidden, device)
    # w2 (L2): packed shape (num_experts, hidden, intermediate_hidden // 2).
    w2_weight, w2_weight_scale, w2_weight_scale_2 = _random_nvfp4_expert_weights(
        num_experts, hidden, intermediate_hidden, device)

    l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
        (w13_weight, w13_weight_scale),
        (w2_weight, w2_weight_scale),
        l1_weight_scale_2=w13_weight_scale_2,
        l2_weight_scale_2=w2_weight_scale_2,
    )
    for label, tensor in (
        ("l1_weights", l1_weights[0]),
        ("l1_sf", l1_weights[1]),
        ("l2_weights", l2_weights[0]),
        ("l2_sf", l2_weights[1]),
    ):
        print(f"{label}: dtype={tensor.dtype} shape={tensor.shape} strides={tensor.stride()}")

    symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
        group, num_experts, num_tokens, top_k, hidden, intermediate_hidden)

    # TODO(review): these inputs are never staged into symm_buffer. The staging
    # helper (_stage_deepseek_v4_mega_moe_inputs) lives in the vLLM patch and
    # cannot be imported here, so the kernel presumably runs on whatever the
    # buffer already holds. Copy them in once the buffer layout is exposed.
    hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1)
    topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)

    y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
    print("Calling fp8_nvfp4_mega_moe...")
    try:
        fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer)
        # CUDA launches are asynchronous; synchronize so device-side faults
        # (e.g. ILLEGAL_INSTRUCTION) are raised inside this try.
        torch.cuda.synchronize()
    except Exception as e:
        print(f"FAILED: {e}")
        raise
    else:
        print(
            f"SUCCESS! y stats: min={y.min().item():.4f} "
            f"max={y.max().item():.4f} mean={y.mean().item():.4f}"
        )


def test_nvfp4_mega_moe():
    """Smoke-test ``fp8_nvfp4_mega_moe`` on GPU 0 with random NVFP4 weights.

    Sets up a single-rank NCCL process group (needed for the symmetric buffer),
    builds random packed weights plus float8_e4m3fn block scales, converts them
    with ``transform_nvfp4_weights_for_mega_moe`` and launches the kernel once.
    The test only checks that the launch completes without raising. Output
    statistics are printed, not compared against a reference.
    """
    torch.cuda.set_device(0)
    # Single-rank process group; environment defaults are only applied when
    # the caller has not already provided them.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    # Only tear the default group down if this test created it, so a group set
    # up by an outer harness is left untouched.
    created_process_group = not dist.is_initialized()
    if created_process_group:
        dist.init_process_group("nccl")
    try:
        _run_nvfp4_mega_moe(dist.new_group(), device="cuda")
    finally:
        if created_process_group:
            dist.destroy_process_group()
if __name__ == "__main__":
    # Allow running this smoke test directly (without pytest).
    test_nvfp4_mega_moe()