Files
nvfp4-megamoe-kernel/tests/unit/test_d2_multicta.py
biondizzle 4826fa6afb D2: add num_query_heads/batch_size params + head-packed test
- FmhaKernel.__init__: add num_query_heads=1, batch_size=1
- Grid: (ceil_div(n_h*T, 128), 1, batch) for multi-CTA
- Test: head-packed multi-head (Q reshaped to (n_h*T, hd))
- n_h=1 regression, n_h=128 Pro decode, n_h=64 Flash, hd=128
2026-05-25 16:50:49 +00:00

132 lines
4.8 KiB
Python

"""
FMHA D2: Multi-head multi-CTA grid approach.
Strategy: Each CTA handles one (head, batch) pair. The grid is
(num_M_tiles, num_query_heads, batch)
Inside the kernel, each CTA computes its Q/O base pointer offset from
block_idx and creates sliced views of Q and O for its specific head.
K/V are shared across all heads (MQA) and loaded once per CTA.
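This test is intended to verify the approach on small configurations
before integrating into FmhaKernel. Both tests currently launch the kernel
once per head; the single-launch multi-CTA grid path is still pending.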
Run: ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_multicta.py
"""
import torch
import math
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
import cuda.bindings.driver as cuda
import cutlass.torch as ct
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_fmha(q, k, v, scale):
    """Compute single-head softmax attention in FP32 as a ground truth.

    The inputs are upcast with ``.float()``, the scores ``q @ k.T * scale``
    are converted to row-wise softmax weights (max-subtracted form), and the
    weighted sum of ``v`` is returned cast to bfloat16.

    Shapes: q (T, hd), k (s_k, hd), v (s_k, hd) -> output (T, hd).
    """
    logits = (q.float() @ k.float().T) * scale  # (T, s_k)
    # Shift each row by its maximum so the exponentials stay bounded by 1.
    row_max, _ = torch.max(logits, dim=-1, keepdim=True)
    weights = torch.exp(logits - row_max)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    out = weights @ v.float()  # (T, hd)
    return out.to(torch.bfloat16)
def test_d2_perhead_regression():
    """Regression check: launch FmhaKernel once per head and compare to FP32.

    Uses n_h=4 heads with T=1, s_k=128, hd=64. Every head has its own Q row
    while a single K/V pair is reused for all heads. Each head's output must
    reach a cosine similarity of at least 0.99 against ``reference_fmha``.
    """
    print("\n=== Test 1: Per-head launch regression (hd=64, n_h=4) ===")
    torch.manual_seed(42)

    T, s_k, hd, n_h = 1, 128, 64, 4
    scale = 1.0 / math.sqrt(hd)

    # Draw order (q, k, v) is kept so the seeded values stay the same.
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')

    def as_cute(tensor):
        # Convert via DLPack and mark the layout dynamic along the leading
        # dimension reported by ct.get_leading_dim.
        return ct.from_dlpack(tensor).mark_layout_dynamic(
            leading_dim=ct.get_leading_dim(tensor)
        )

    # One kernel object serves all heads; only the Q/O views change per launch.
    fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
    o = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    stream = cuda.cuStream(0)
    for head in range(n_h):
        fmha(as_cute(q[head]), as_cute(k), as_cute(v), as_cute(o[head]), stream)

    # Compare every head against the FP32 reference.
    for h in range(n_h):
        expected = reference_fmha(q[h], k, v, scale)
        got_row = o[h].flatten().float().unsqueeze(0)
        expected_row = expected.flatten().float().unsqueeze(0)
        similarity = torch.nn.functional.cosine_similarity(got_row, expected_row).item()
        print(f" Head {h}: cos = {similarity:.6f}")
        assert similarity >= 0.99, f"Head {h} cosine too low: {similarity}"
    print(" ✅ PASS")
def test_d2_multicta_basic():
    """Placeholder for the multi-CTA grid test, using n_h=2 heads.

    The intended design is a single kernel call whose CTAs derive their Q/O
    offsets from block_idx. That path is not exercised here: as the printed
    notice says, this test still launches FmhaKernel once per head and checks
    each head against ``reference_fmha`` (cosine similarity >= 0.99), so it
    acts as the baseline the multi-CTA grid should later reproduce.
    """
    print("\n=== Test 2: Multi-CTA grid basic (hd=64, n_h=2) ===")
    print(" (Using per-head launch as proxy — multi-CTA grid refactor pending)")
    torch.manual_seed(42)

    T, s_k, hd, n_h = 1, 128, 64, 2
    scale = 1.0 / math.sqrt(hd)

    # Draw order (q, k, v) is kept so the seeded values stay the same.
    q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')

    def as_cute(tensor):
        # Convert via DLPack and mark the layout dynamic along the leading
        # dimension reported by ct.get_leading_dim.
        return ct.from_dlpack(tensor).mark_layout_dynamic(
            leading_dim=ct.get_leading_dim(tensor)
        )

    fmha = FmhaKernel(head_dim=hd, s_k=s_k, normalize=True)
    o = torch.zeros(n_h, T, hd, dtype=torch.bfloat16, device='cuda')
    stream = cuda.cuStream(0)
    for head in range(n_h):
        fmha(as_cute(q[head]), as_cute(k), as_cute(v), as_cute(o[head]), stream)

    # Compare every head against the FP32 reference.
    for h in range(n_h):
        expected = reference_fmha(q[h], k, v, scale)
        got_row = o[h].flatten().float().unsqueeze(0)
        expected_row = expected.flatten().float().unsqueeze(0)
        similarity = torch.nn.functional.cosine_similarity(got_row, expected_row).item()
        print(f" Head {h}: cos = {similarity:.6f}")
        assert similarity >= 0.99, f"Head {h} cosine too low: {similarity}"
    print(" ✅ PASS")
def test():
    """Run the D2 tests in order, printing a banner before and after.

    A failing assertion in either test propagates, so the closing banner is
    only printed when both tests complete.
    """
    print("=== D2: Multi-Head FMHA Tests ===")
    for case in (test_d2_perhead_regression, test_d2_multicta_basic):
        case()
    print("\n=== ALL TESTS PASSED ===")
# Script entry point: run the whole suite when the file is executed directly.
if __name__ == "__main__":
    test()