Packed E2M1 output stores two elements per byte, so a row of block_n elements occupies block_n/2 bytes. Sizing it as block_n/4 bytes under-sized the TMA shared-memory row by 2x, causing an out-of-bounds write and a LAUNCH_FAILED error.
98 lines
4.1 KiB
Python
98 lines
4.1 KiB
Python
"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data."""
|
|
import os
import unittest

import torch
import torch.distributed as dist
|
|
|
|
def test_nvfp4_mega_moe():
|
|
# Small dimensions that satisfy alignment requirements
|
|
# hidden and intermediate_hidden must be multiples of 128
|
|
# hidden must be divisible by 64 (for NVFP4 SF packing)
|
|
num_experts = 2
|
|
num_tokens = 4
|
|
top_k = 2
|
|
hidden = 256 # must be multiple of 128 and 64
|
|
intermediate_hidden = 512 # must be multiple of 128 and 64
|
|
|
|
device = "cuda"
|
|
torch.cuda.set_device(0)
|
|
|
|
# Create a single-rank process group for SymmBuffer
|
|
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
|
|
os.environ.setdefault("MASTER_PORT", "29500")
|
|
os.environ.setdefault("RANK", "0")
|
|
os.environ.setdefault("WORLD_SIZE", "1")
|
|
if not dist.is_initialized():
|
|
dist.init_process_group("nccl")
|
|
group = dist.new_group()
|
|
|
|
from deep_gemm.mega import (
|
|
fp8_nvfp4_mega_moe,
|
|
get_symm_buffer_for_nvfp4_mega_moe,
|
|
transform_nvfp4_weights_for_mega_moe,
|
|
)
|
|
|
|
# Create random NVFP4 weights (E2M1 packed int8 + float8_e4m3fn block scales)
|
|
# w13: (num_experts, 2*intermediate_hidden, hidden//2)
|
|
w13_weight = torch.randint(0, 256, (num_experts, 2 * intermediate_hidden, hidden // 2),
|
|
dtype=torch.uint8, device=device).view(torch.int8)
|
|
w13_weight_scale = torch.randn(num_experts, 2 * intermediate_hidden, hidden // 16,
|
|
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
|
|
w13_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
|
|
w13_input_scale = torch.ones(num_experts, device=device)
|
|
|
|
# w2: (num_experts, hidden, intermediate_hidden//2)
|
|
w2_weight = torch.randint(0, 256, (num_experts, hidden, intermediate_hidden // 2),
|
|
dtype=torch.uint8, device=device).view(torch.int8)
|
|
w2_weight_scale = torch.randn(num_experts, hidden, intermediate_hidden // 16,
|
|
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
|
|
w2_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
|
|
w2_input_scale = torch.ones(num_experts, device=device)
|
|
|
|
# Transform weights for the kernel
|
|
l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
|
|
(w13_weight, w13_weight_scale),
|
|
(w2_weight, w2_weight_scale),
|
|
l1_weight_scale_2=w13_weight_scale_2,
|
|
l2_weight_scale_2=w2_weight_scale_2,
|
|
)
|
|
|
|
print(f"l1_weights: dtype={l1_weights[0].dtype} shape={l1_weights[0].shape} strides={l1_weights[0].stride()}")
|
|
print(f"l1_sf: dtype={l1_weights[1].dtype} shape={l1_weights[1].shape} strides={l1_weights[1].stride()}")
|
|
print(f"l2_weights: dtype={l2_weights[0].dtype} shape={l2_weights[0].shape} strides={l2_weights[0].stride()}")
|
|
print(f"l2_sf: dtype={l2_weights[1].dtype} shape={l2_weights[1].shape} strides={l2_weights[1].stride()}")
|
|
|
|
# Create symm buffer
|
|
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
|
|
group, num_experts, num_tokens, top_k, hidden, intermediate_hidden)
|
|
|
|
# Create input (BF16)
|
|
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
|
|
|
|
# Create topk weights/ids
|
|
topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1)
|
|
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)
|
|
|
|
# Stage inputs
|
|
from deepseek_v4_staging import _stage_deepseek_v4_mega_moe_inputs
|
|
# Actually, we can't import from vllm patch. Let's just manually set up the symm buffer.
|
|
|
|
# Output tensor
|
|
y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
|
|
|
|
# Call the kernel
|
|
print("Calling fp8_nvfp4_mega_moe...")
|
|
try:
|
|
fp8_nvfp4_mega_moe(
|
|
y,
|
|
l1_weights, l2_weights,
|
|
symm_buffer,
|
|
)
|
|
print("SUCCESS! y stats: min={:.4f} max={:.4f} mean={:.4f}".format(
|
|
y.min().item(), y.max().item(), y.mean().item()))
|
|
except Exception as e:
|
|
print(f"FAILED: {e}")
|
|
raise
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script (not collected by pytest): run the smoke test.
    test_nvfp4_mega_moe()
|