fix: L1 output TMA smem_inner_dim was block_n/4, should be block_n/2
Packed E2M1 output has 2 elements per byte, so block_n elements = block_n/2 bytes. block_n/4 was under-sizing the TMA SMEM row by 2x → OOB write → LAUNCH_FAILED.
This commit is contained in:
@@ -157,10 +157,10 @@ static void sm100_fp8_nvfp4_mega_moe(
|
||||
intermediate_hidden * 2, hidden,
|
||||
config.block_n, kGranK,
|
||||
num_experts_per_rank, 0);
|
||||
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/4 bytes, no swizzle (v1)
|
||||
// L1 output: packed E2M1, K-dim = intermediate_hidden/2, inner = block_n/2 bytes (packed), no swizzle (v1)
|
||||
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
|
||||
intermediate_hidden / 2, config.num_max_pool_tokens,
|
||||
config.block_n / 4, config.store_block_m,
|
||||
config.block_n / 2, config.store_block_m,
|
||||
static_cast<int>(l2_acts.stride(-2)),
|
||||
0, 0, // no swizzle
|
||||
false, // allow_tf32
|
||||
|
||||
97
test_nvfp4_mega_moe.py
Normal file
97
test_nvfp4_mega_moe.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Minimal test for fp8_nvfp4_mega_moe kernel with synthetic data."""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import os
|
||||
|
||||
def test_nvfp4_mega_moe():
|
||||
# Small dimensions that satisfy alignment requirements
|
||||
# hidden and intermediate_hidden must be multiples of 128
|
||||
# hidden must be divisible by 64 (for NVFP4 SF packing)
|
||||
num_experts = 2
|
||||
num_tokens = 4
|
||||
top_k = 2
|
||||
hidden = 256 # must be multiple of 128 and 64
|
||||
intermediate_hidden = 512 # must be multiple of 128 and 64
|
||||
|
||||
device = "cuda"
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# Create a single-rank process group for SymmBuffer
|
||||
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
|
||||
os.environ.setdefault("MASTER_PORT", "29500")
|
||||
os.environ.setdefault("RANK", "0")
|
||||
os.environ.setdefault("WORLD_SIZE", "1")
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group("nccl")
|
||||
group = dist.new_group()
|
||||
|
||||
from deep_gemm.mega import (
|
||||
fp8_nvfp4_mega_moe,
|
||||
get_symm_buffer_for_nvfp4_mega_moe,
|
||||
transform_nvfp4_weights_for_mega_moe,
|
||||
)
|
||||
|
||||
# Create random NVFP4 weights (E2M1 packed int8 + float8_e4m3fn block scales)
|
||||
# w13: (num_experts, 2*intermediate_hidden, hidden//2)
|
||||
w13_weight = torch.randint(0, 256, (num_experts, 2 * intermediate_hidden, hidden // 2),
|
||||
dtype=torch.uint8, device=device).view(torch.int8)
|
||||
w13_weight_scale = torch.randn(num_experts, 2 * intermediate_hidden, hidden // 16,
|
||||
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
|
||||
w13_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
|
||||
w13_input_scale = torch.ones(num_experts, device=device)
|
||||
|
||||
# w2: (num_experts, hidden, intermediate_hidden//2)
|
||||
w2_weight = torch.randint(0, 256, (num_experts, hidden, intermediate_hidden // 2),
|
||||
dtype=torch.uint8, device=device).view(torch.int8)
|
||||
w2_weight_scale = torch.randn(num_experts, hidden, intermediate_hidden // 16,
|
||||
device=device).abs().clamp(0.1, 10.0).to(torch.float8_e4m3fn)
|
||||
w2_weight_scale_2 = torch.randn(num_experts, device=device).abs().clamp(0.5, 2.0)
|
||||
w2_input_scale = torch.ones(num_experts, device=device)
|
||||
|
||||
# Transform weights for the kernel
|
||||
l1_weights, l2_weights = transform_nvfp4_weights_for_mega_moe(
|
||||
(w13_weight, w13_weight_scale),
|
||||
(w2_weight, w2_weight_scale),
|
||||
l1_weight_scale_2=w13_weight_scale_2,
|
||||
l2_weight_scale_2=w2_weight_scale_2,
|
||||
)
|
||||
|
||||
print(f"l1_weights: dtype={l1_weights[0].dtype} shape={l1_weights[0].shape} strides={l1_weights[0].stride()}")
|
||||
print(f"l1_sf: dtype={l1_weights[1].dtype} shape={l1_weights[1].shape} strides={l1_weights[1].stride()}")
|
||||
print(f"l2_weights: dtype={l2_weights[0].dtype} shape={l2_weights[0].shape} strides={l2_weights[0].stride()}")
|
||||
print(f"l2_sf: dtype={l2_weights[1].dtype} shape={l2_weights[1].shape} strides={l2_weights[1].stride()}")
|
||||
|
||||
# Create symm buffer
|
||||
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
|
||||
group, num_experts, num_tokens, top_k, hidden, intermediate_hidden)
|
||||
|
||||
# Create input (BF16)
|
||||
hidden_states = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# Create topk weights/ids
|
||||
topk_weights = torch.softmax(torch.randn(num_tokens, top_k, device=device), dim=-1)
|
||||
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k), device=device)
|
||||
|
||||
# Stage inputs
|
||||
from deepseek_v4_staging import _stage_deepseek_v4_mega_moe_inputs
|
||||
# Actually, we can't import from vllm patch. Let's just manually set up the symm buffer.
|
||||
|
||||
# Output tensor
|
||||
y = torch.zeros(num_tokens, hidden, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# Call the kernel
|
||||
print("Calling fp8_nvfp4_mega_moe...")
|
||||
try:
|
||||
fp8_nvfp4_mega_moe(
|
||||
y,
|
||||
l1_weights, l2_weights,
|
||||
symm_buffer,
|
||||
)
|
||||
print("SUCCESS! y stats: min={:.4f} max={:.4f} mean={:.4f}".format(
|
||||
y.min().item(), y.max().item(), y.mean().item()))
|
||||
except Exception as e:
|
||||
print(f"FAILED: {e}")
|
||||
raise
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_nvfp4_mega_moe()
|
||||
Reference in New Issue
Block a user