D1.5: Add isolated round-trip test comparing s_k=128 vs s_k=256 with NOOP rescale

This commit is contained in:
2026-05-26 20:45:58 +00:00
parent e35b30dae6
commit 42c5793add
2 changed files with 123 additions and 7 deletions

View File

@@ -418,18 +418,26 @@ class FmhaKernel:
# sub-tiles, and building BOTH copies from the SAME tensor, ensures the
# column mappings agree on round-trip.
# ============================================================
corr_tile_size = 32 # Must be power of 2, divides head_dim. Try 32 instead of 16.
tOtO_i_layout = cute.composition(
tOtO0.layout, cute.make_layout((128, corr_tile_size))
)
corr_tile_size = 16 # Must be power of 2, divides head_dim
# Try both composition and raw layout
use_comp = True
if const_expr(use_comp):
tOtO_i_layout = cute.composition(
tOtO0.layout, cute.make_layout((128, corr_tile_size))
)
else:
tOtO_i_layout = tOtO0.layout
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
# Coordinate tensor for O (needed for partition_D of load)
cO = cute.make_identity_tensor((128, self.head_dim))
tOcO = pv_thr.partition_C(cO)
tOcO_i_layout = cute.composition(
tOcO.layout, cute.make_layout((128, corr_tile_size))
)
if const_expr(use_comp):
tOcO_i_layout = cute.composition(
tOcO.layout, cute.make_layout((128, corr_tile_size))
)
else:
tOcO_i_layout = tOcO.layout
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(

View File

@@ -0,0 +1,108 @@
"""
D1.5 Minimal TMEM round-trip test within FMHA.
Tests: run FMHA with s_k=128, then add a NOOP correction_rescale
(multiply by 1.0) BETWEEN the softmax and the epilogue.
If the output is bitwise identical to without rescale, the round-trip works.
If it differs, the round-trip corrupts data.
This test directly compares: same kernel, same data, WITH and WITHOUT
the O rescale step (forced to NOOP).
"""
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
hd = 64; m = 128; s_k = 128; scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
# Run WITHOUT rescale (s_k=128, n_kv_tiles=1, rescale code is dead)
kernel1 = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False)
pv_n_tile = kernel1.pv_n_tile
c1 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
rs1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
v_t = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
mV = ct.from_dlpack(v_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_t))
mC1 = ct.from_dlpack(c1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c1))
mLSE1 = ct.from_dlpack(lse1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse1))
mRS1 = ct.from_dlpack(rs1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs1))
print('Compiling s_k=128 (no rescale)...', flush=True)
comp1 = cute.compile(kernel1, mQ, mK, mV, mC1, stream, mLSE1, row_sums=mRS1)
comp1(mQ, mK, mV, mC1, stream, mLSE1, row_sums=mRS1)
torch.cuda.synchronize()
# Run WITH NOOP rescale (s_k=256, but with debug_noop_rescale=True,
# using same K/V repeated to make the second segment identical)
# This triggers the O rescale code path (n_kv_tiles=2) but with factor=1.0
k2 = torch.cat([k, k], dim=0)
v2 = torch.cat([v, v], dim=0)
kernel2 = FmhaKernel(head_dim=hd, s_k=256, use_smem_p=False, normalize=False, debug_noop_rescale=True)
c2 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse2 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
rs2 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
mK2 = ct.from_dlpack(k2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k2))
v2_t = v2[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
mV2 = ct.from_dlpack(v2_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v2_t))
mC2 = ct.from_dlpack(c2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c2))
mLSE2 = ct.from_dlpack(lse2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse2))
mRS2 = ct.from_dlpack(rs2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs2))
print('Compiling s_k=256 (NOOP rescale)...', flush=True)
comp2 = cute.compile(kernel2, mQ, mK2, mV2, mC2, stream, mLSE2, row_sums=mRS2)
comp2(mQ, mK2, mV2, mC2, stream, mLSE2, row_sums=mRS2)
torch.cuda.synchronize()
# FP32 reference
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
attn = qf @ kf.T * scale
attn_max = attn.max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(attn - attn_max)
ref_norm = (attn_exp @ v.float()) / attn_exp.sum(dim=-1, keepdim=True)
# Normalize both outputs
out1 = c1[:, :, 0].float() / rs1[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
out2 = c2[:, :, 0].float() / rs2[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
cos1 = torch.nn.functional.cosine_similarity(out1.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)).item()
cos2 = torch.nn.functional.cosine_similarity(out2.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)).item()
cos_12 = torch.nn.functional.cosine_similarity(out1.flatten().unsqueeze(0), out2.flatten().unsqueeze(0)).item()
# Also check unnormalized
unnorm1 = c1[:, :, 0].float()
unnorm2 = c2[:, :, 0].float()
# With identical segments, s_k=256 unnorm should be 2x s_k=128 unnorm
# (row_sum also doubles, so normalized result is the same)
ratio = unnorm2 / unnorm1.clamp(min=1e-10)
print(f'Unnorm ratio (should be ~2.0): mean={ratio.mean().item():.4f} std={ratio.std().item():.4f}')
print(f's_k=128 (no rescale): cos={cos1:.6f}')
print(f's_k=256 (NOOP rescale): cos={cos2:.6f}')
print(f's_k=128 vs s_k=256: cos={cos_12:.6f}')
# Bitwise comparison on unnormalized O
# s_k=256 with identical segments: O_unnorm = 2 * O_unnorm_128
expected_unnorm2 = 2.0 * unnorm1
bit_cos = torch.nn.functional.cosine_similarity(unnorm2.flatten().unsqueeze(0), expected_unnorm2.flatten().unsqueeze(0)).item()
max_diff = (unnorm2 - expected_unnorm2).abs().max().item()
print(f'Unnorm: expected 2x, got cos={bit_cos:.6f} max_diff={max_diff:.4f}')
if __name__ == '__main__':
test()