D1.5: Add isolated round-trip test comparing s_k=128 vs s_k=256 with NOOP rescale
This commit is contained in:
@@ -418,18 +418,26 @@ class FmhaKernel:
|
||||
# sub-tiles, and building BOTH copies from the SAME tensor, ensures the
|
||||
# column mappings agree on round-trip.
|
||||
# ============================================================
|
||||
corr_tile_size = 32 # Must be power of 2, divides head_dim. Try 32 instead of 16.
|
||||
tOtO_i_layout = cute.composition(
|
||||
tOtO0.layout, cute.make_layout((128, corr_tile_size))
|
||||
)
|
||||
corr_tile_size = 16 # Must be power of 2, divides head_dim
|
||||
# Try both composition and raw layout
|
||||
use_comp = True
|
||||
if const_expr(use_comp):
|
||||
tOtO_i_layout = cute.composition(
|
||||
tOtO0.layout, cute.make_layout((128, corr_tile_size))
|
||||
)
|
||||
else:
|
||||
tOtO_i_layout = tOtO0.layout
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
|
||||
# Coordinate tensor for O (needed for partition_D of load)
|
||||
cO = cute.make_identity_tensor((128, self.head_dim))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
tOcO_i_layout = cute.composition(
|
||||
tOcO.layout, cute.make_layout((128, corr_tile_size))
|
||||
)
|
||||
if const_expr(use_comp):
|
||||
tOcO_i_layout = cute.composition(
|
||||
tOcO.layout, cute.make_layout((128, corr_tile_size))
|
||||
)
|
||||
else:
|
||||
tOcO_i_layout = tOcO.layout
|
||||
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
|
||||
tmem_load_o_atom = cute.make_copy_atom(
|
||||
|
||||
108
tests/unit/test_d15_roundtrip_iso.py
Normal file
108
tests/unit/test_d15_roundtrip_iso.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
D1.5 Minimal TMEM round-trip test within FMHA.
|
||||
|
||||
Tests: run FMHA once with s_k=128 (a single KV tile, so the O rescale code
is never executed), then run it again with s_k=256 on the same K/V repeated
twice and debug_noop_rescale=True, which exercises the O rescale path with
the factor forced to 1.0.

If the unnormalized s_k=256 output equals 2x the s_k=128 output (and the
normalized outputs agree), the round-trip works.
If it differs, the round-trip corrupts data.

This test directly compares: same Q, same K/V data, WITH and WITHOUT
the O rescale step (forced to NOOP).
|
||||
"""
|
||||
import torch, math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def test():
    """Compare FMHA output for s_k=128 against s_k=256 with a NOOP O rescale.

    Two kernels are compiled and run on the same Q:

    * ``s_k=128`` -- per the original notes this is a single KV tile, so the
      O rescale code is dead.
    * ``s_k=256`` with ``debug_noop_rescale=True`` -- K and V are the
      ``s_k=128`` tensors concatenated with themselves, which (per the
      original notes) triggers the O rescale path with a factor of 1.0.

    Because the second KV segment duplicates the first, the unnormalized
    output and the row sums of the second run are both expected to be twice
    those of the first, so the normalized outputs should agree.

    The function is diagnostic only: it prints cosine similarities against
    an FP32 reference, the element-wise ratio of the unnormalized outputs,
    and whether the unnormalized output is bit-for-bit equal to 2x the
    ``s_k=128`` output. It returns nothing and makes no assertions.
    Requires a CUDA device.
    """
    hd = 64
    m = 128
    s_k = 128
    scale = 1.0 / math.sqrt(hd)
    torch.manual_seed(42)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')

    # Run WITHOUT rescale (s_k=128, n_kv_tiles=1, rescale code is dead)
    kernel1 = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=False, normalize=False)
    pv_n_tile = kernel1.pv_n_tile
    c1 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    lse1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    rs1 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    # Only the first pv_n_tile columns of V are handed to the kernel.
    v_t = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    mV = ct.from_dlpack(v_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_t))
    mC1 = ct.from_dlpack(c1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c1))
    mLSE1 = ct.from_dlpack(lse1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse1))
    mRS1 = ct.from_dlpack(rs1).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs1))

    print('Compiling s_k=128 (no rescale)...', flush=True)
    comp1 = cute.compile(kernel1, mQ, mK, mV, mC1, stream, mLSE1, row_sums=mRS1)
    comp1(mQ, mK, mV, mC1, stream, mLSE1, row_sums=mRS1)
    torch.cuda.synchronize()

    # Run WITH NOOP rescale (s_k=256, but with debug_noop_rescale=True,
    # using same K/V repeated to make the second segment identical)
    # This triggers the O rescale code path (n_kv_tiles=2) but with factor=1.0
    k2 = torch.cat([k, k], dim=0)
    v2 = torch.cat([v, v], dim=0)

    kernel2 = FmhaKernel(head_dim=hd, s_k=256, use_smem_p=False, normalize=False, debug_noop_rescale=True)
    c2 = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    lse2 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    rs2 = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    mK2 = ct.from_dlpack(k2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k2))
    v2_t = v2[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    mV2 = ct.from_dlpack(v2_t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v2_t))
    mC2 = ct.from_dlpack(c2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c2))
    mLSE2 = ct.from_dlpack(lse2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse2))
    mRS2 = ct.from_dlpack(rs2).mark_layout_dynamic(leading_dim=ct.get_leading_dim(rs2))

    print('Compiling s_k=256 (NOOP rescale)...', flush=True)
    comp2 = cute.compile(kernel2, mQ, mK2, mV2, mC2, stream, mLSE2, row_sums=mRS2)
    comp2(mQ, mK2, mV2, mC2, stream, mLSE2, row_sums=mRS2)
    torch.cuda.synchronize()

    # FP32 reference
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    attn = qf @ kf.T * scale
    attn_max = attn.max(dim=-1, keepdim=True)[0]
    attn_exp = torch.exp(attn - attn_max)
    # Use the same V columns the kernel received so the reference has the
    # same shape as the kernel output even when pv_n_tile != head_dim.
    v_ref = v[:, 0:pv_n_tile].float()
    ref_norm = (attn_exp @ v_ref) / attn_exp.sum(dim=-1, keepdim=True)

    # Normalize both outputs
    out1 = c1[:, :, 0].float() / rs1[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)
    out2 = c2[:, :, 0].float() / rs2[:, 0, 0].float().unsqueeze(1).clamp(min=1e-30)

    cos1 = torch.nn.functional.cosine_similarity(out1.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)).item()
    cos2 = torch.nn.functional.cosine_similarity(out2.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)).item()
    cos_12 = torch.nn.functional.cosine_similarity(out1.flatten().unsqueeze(0), out2.flatten().unsqueeze(0)).item()

    # Also check unnormalized
    unnorm1 = c1[:, :, 0].float()
    unnorm2 = c2[:, :, 0].float()
    # With identical segments, s_k=256 unnorm should be 2x s_k=128 unnorm
    # (row_sum also doubles, so normalized result is the same)
    # The unnormalized output is signed, so a lower clamp on the denominator
    # would replace every negative entry with a tiny positive number and
    # blow the ratio up. Skip near-zero denominators by magnitude instead.
    usable = unnorm1.abs() > 1e-10
    ratio = unnorm2[usable] / unnorm1[usable]
    if ratio.numel() > 1:
        ratio_mean = ratio.mean().item()
        ratio_std = ratio.std().item()
    else:
        # Too few usable elements for meaningful statistics.
        ratio_mean = ratio_std = float('nan')
    print(f'Unnorm ratio (should be ~2.0): mean={ratio_mean:.4f} std={ratio_std:.4f}')

    print(f's_k=128 (no rescale): cos={cos1:.6f}')
    print(f's_k=256 (NOOP rescale): cos={cos2:.6f}')
    print(f's_k=128 vs s_k=256: cos={cos_12:.6f}')

    # Comparison on unnormalized O
    # s_k=256 with identical segments: O_unnorm = 2 * O_unnorm_128
    expected_unnorm2 = 2.0 * unnorm1
    unnorm_cos = torch.nn.functional.cosine_similarity(unnorm2.flatten().unsqueeze(0), expected_unnorm2.flatten().unsqueeze(0)).item()
    max_diff = (unnorm2 - expected_unnorm2).abs().max().item()
    print(f'Unnorm: expected 2x, got cos={unnorm_cos:.6f} max_diff={max_diff:.4f}')
    # Doubling a bf16 value widened to fp32 is exact, so element-wise
    # equality here is a genuine bit-exact check of the round-trip.
    bit_exact = torch.equal(unnorm2, expected_unnorm2)
    print(f'Unnorm bitwise equal to 2x: {bit_exact}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user