Files
nvfp4-megamoe-kernel/tests/unit/test_d15_in_kernel_rescale.py
biondizzle bf2c7c8bb8 D1.5: Implement in-kernel O rescale via CUTLASS correction_rescale pattern
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
2026-05-26 20:26:06 +00:00

153 lines
6.4 KiB
Python

"""
D1.5 Phase 4: Test in-kernel O rescale for multi-KV-tile FMHA.
Tests the CUTLASS correction_rescale pattern:
- Both load and store atoms built from the SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both
- Rescale O in TMEM between PV iterations
Compares against:
1. FP32 reference (ground truth)
2. Python KV merge (proven correct, cos 0.999998)
3. s_k=128 baseline (no rescale, regression check)
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def reference_attention(q, k, v, scale):
    """Compute the FP32 reference attention output without softmax normalization.

    The scores ``q @ k^T * scale`` are shifted by their per-row maximum and
    exponentiated; the resulting weights are applied to ``v`` WITHOUT dividing
    by the per-row sum of weights.

    Args:
        q: Query tensor whose last two dims are ``(m, head_dim)``.
        k: Key tensor whose last two dims are ``(s_k, head_dim)``.
        v: Value tensor whose last two dims are ``(s_k, d_v)``.
        scale: Scalar multiplier applied to the raw scores.

    Returns:
        Float32 tensor ``exp(S - rowmax(S)) @ v`` whose last two dims are
        ``(m, d_v)``.
    """
    qf = q.float()
    kf = k.float()
    # transpose(-2, -1) rather than .T: identical for 2-D keys, but also
    # correct for batched keys, where .T is deprecated and would reverse
    # every dimension instead of swapping only the last two.
    scores = qf @ kf.transpose(-2, -1) * scale
    row_max = scores.amax(dim=-1, keepdim=True)
    weights = torch.exp(scores - row_max)
    return weights @ v.float()
def run_fmha(q, k, v, head_dim, s_k, pv_n_tile, use_smem_p, stream, lse_tensor, row_sums_tensor):
    """Compile and launch ``FmhaKernel`` once on the first ``pv_n_tile`` columns of ``v``.

    The kernel is constructed with ``normalize=False``. This function does not
    synchronize; the callers in this file call ``torch.cuda.synchronize()``
    before reading the results.

    Args:
        q: Query tensor; ``q.shape[0]`` is used as the number of output rows.
        k: Key tensor, passed to the kernel unchanged.
        v: Value tensor indexed as ``v[:, :pv_n_tile]``; only that column
            slice (made contiguous, with a trailing unit dim appended) is
            handed to the kernel.
        head_dim: Forwarded to ``FmhaKernel(head_dim=...)``.
        s_k: Forwarded to ``FmhaKernel(s_k=...)``.
        pv_n_tile: Number of leading ``v`` columns to process; also the width
            of the output buffer.
        use_smem_p: Forwarded to ``FmhaKernel(use_smem_p=...)``.
        stream: Stream handle passed to both the compile and launch calls.
        lse_tensor: Caller-allocated tensor passed to the kernel as its LSE
            argument.
        row_sums_tensor: Caller-allocated tensor passed to the kernel as
            ``row_sums``.

    Returns:
        Tuple ``(c_tile, lse_tensor, row_sums_tensor, kernel)`` where
        ``c_tile`` is the freshly allocated bfloat16 output buffer of shape
        ``(q.shape[0], pv_n_tile, 1)``, the two caller-supplied tensors are
        returned as passed in, and ``kernel`` is the ``FmhaKernel`` instance.
    """

    def _as_cute(t):
        """Wrap a torch tensor for the kernel, marking its layout dynamic."""
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    m = q.shape[0]
    v_kernel = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    # Allocate the output next to q instead of on whatever the current CUDA
    # device happens to be, so inputs and output always share a device.
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device=q.device)

    mQ = _as_cute(q)
    mK = _as_cute(k)
    mV = _as_cute(v_kernel)
    mC = _as_cute(c_tile)
    mLSE = _as_cute(lse_tensor)
    mRS = _as_cute(row_sums_tensor)

    kernel = FmhaKernel(head_dim=head_dim, s_k=s_k, use_smem_p=use_smem_p, normalize=False)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
    compiled(mQ, mK, mV, mC, stream, mLSE, row_sums=mRS)
    return c_tile, lse_tensor, row_sums_tensor, kernel
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors after flattening each to 1-D."""
    return torch.nn.functional.cosine_similarity(
        a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0)
    ).item()


def _score_against_reference(q, k, v, hd, s_k, pv_n_tile, stream, scale):
    """Run the kernel on ``(q, k, v)`` and return its cosine vs. the FP32 reference."""
    m = q.shape[0]
    lse = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    rs = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    out, _, _, _ = run_fmha(q, k, v, hd, s_k, pv_n_tile, False, stream, lse, rs)
    torch.cuda.synchronize()
    # run_fmha only feeds the first pv_n_tile columns of V to the kernel, so
    # the reference must be built from the same slice to have a matching shape.
    ref = reference_attention(q[:, :, 0], k[:, :, 0], v[:, :pv_n_tile], scale)
    return _flat_cosine(out[:, :, 0], ref)


def test():
    """Check the in-kernel O rescale at s_k = 128, 256 and 384.

    Each kernel output (un-normalized, since ``run_fmha`` builds the kernel
    with ``normalize=False``) is compared by flattened cosine similarity with
    the un-normalized FP32 reference. A Python-side merge of two s_k=128 runs
    is also reported as an oracle for the s_k=256 case; it is printed but does
    not take part in the pass/fail decision.

    Raises:
        AssertionError: if any of the three kernel-vs-reference cosines is
            below 0.999.
    """
    hd = 64
    m = 128
    threshold = 0.999
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')

    # ===== Test 1: s_k=128 baseline (no rescale) =====
    s_k1 = 128
    k1 = torch.randn(s_k1, hd, 1, dtype=torch.bfloat16, device='cuda')
    v1 = torch.randn(s_k1, hd, dtype=torch.bfloat16, device='cuda')
    # Construct (but do not launch) a kernel just to read its pv_n_tile.
    pv_n_tile = FmhaKernel(head_dim=hd, s_k=s_k1, use_smem_p=False, normalize=False).pv_n_tile
    cos1 = _score_against_reference(q, k1, v1, hd, s_k1, pv_n_tile, stream, scale)
    status1 = "PASS" if cos1 >= threshold else "FAIL"
    print(f'Test 1: s_k=128 baseline: cos={cos1:.6f} {status1}', flush=True)

    # ===== Test 2: s_k=256 with in-kernel rescale (CUTLASS correction_rescale) =====
    s_k2 = 256
    k2 = torch.randn(s_k2, hd, 1, dtype=torch.bfloat16, device='cuda')
    v2 = torch.randn(s_k2, hd, dtype=torch.bfloat16, device='cuda')
    cos2 = _score_against_reference(q, k2, v2, hd, s_k2, pv_n_tile, stream, scale)
    status2 = "PASS" if cos2 >= threshold else "FAIL"
    print(f'Test 2: s_k=256 in-kernel rescale: cos={cos2:.6f} {status2}', flush=True)

    # ===== Test 3: Python KV merge (oracle) =====
    # Run the two 128-row halves of (k2, v2) as independent single-tile problems.
    lse_s0 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    rs_s0 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    c_s0, lse_s0, rs_s0, _ = run_fmha(q, k2[:128], v2[:128], hd, 128, pv_n_tile, False, stream, lse_s0, rs_s0)
    lse_s1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    rs_s1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    c_s1, lse_s1, rs_s1, _ = run_fmha(q, k2[128:], v2[128:], hd, 128, pv_n_tile, False, stream, lse_s1, rs_s1)
    torch.cuda.synchronize()

    # D5 merge: O = sum(exp(lse_i) * O_i_norm) / sum(exp(lse_i))
    o0_norm = c_s0[:, :, 0].float() / rs_s0[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    o1_norm = c_s1[:, :, 0].float() / rs_s1[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    # softmax over the two LSEs equals exp(lse_i) / sum_j exp(lse_j) but
    # cannot overflow for large LSE values.
    lse_pair = torch.stack([lse_s0[:, 0, 0].float(), lse_s1[:, 0, 0].float()], dim=0)
    weights = torch.softmax(lse_pair, dim=0)
    oracle = weights[0].unsqueeze(1) * o0_norm + weights[1].unsqueeze(1) * o1_norm

    # The merged oracle is a NORMALIZED attention output, so it has to be
    # scored against a normalized reference; reference_attention() returns the
    # un-normalized numerator, whose per-row scale would skew a flattened cosine.
    scores2 = q[:, :, 0].float() @ k2[:, :, 0].float().transpose(-2, -1) * scale
    ref2_norm = torch.softmax(scores2, dim=-1) @ v2[:, :pv_n_tile].float()
    cos_oracle = _flat_cosine(oracle, ref2_norm)
    print(f'Oracle: Python KV merge: cos={cos_oracle:.6f}', flush=True)

    # ===== Test 4: s_k=384 (3 KV tiles) =====
    s_k3 = 384
    k3 = torch.randn(s_k3, hd, 1, dtype=torch.bfloat16, device='cuda')
    v3 = torch.randn(s_k3, hd, dtype=torch.bfloat16, device='cuda')
    cos3 = _score_against_reference(q, k3, v3, hd, s_k3, pv_n_tile, stream, scale)
    status3 = "PASS" if cos3 >= threshold else "FAIL"
    print(f'Test 4: s_k=384 in-kernel rescale: cos={cos3:.6f} {status3}', flush=True)

    # ===== Summary =====
    all_pass = cos1 >= threshold and cos2 >= threshold and cos3 >= threshold
    print(f'\nSummary: {"ALL PASS ✅" if all_pass else "SOME FAIL ❌"}', flush=True)
    # Printing alone lets a regression slip through pytest / CI; fail loudly.
    if not all_pass:
        raise AssertionError(
            f'in-kernel rescale cosine below {threshold}: '
            f's_k=128 {cos1:.6f}, s_k=256 {cos2:.6f}, s_k=384 {cos3:.6f}'
        )
# Script entry point: run the checks when this file is executed directly.
if __name__ == "__main__":
    test()