D1.4: Add external k_sub merge test for hd=512 (avoids slow in-kernel k_sub compilation)

This commit is contained in:
2026-05-24 16:31:06 +00:00
parent 13fcf16b14
commit e6c9e6c0d0

View File

@@ -0,0 +1,137 @@
"""D1.4 hd=512 test using external k_sub merge.
Instead of the k_sub path in the kernel (which causes 45+ min compilation),
we call the kernel once per k_sub tile with Q and K pre-sliced.
The online softmax merge (same as D5) combines the partial results.
The kernel always runs at k_tile=256 (same as hd=256, proven to compile fast).
"""
import torch, math, time
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
torch.manual_seed(42)
hd, n = 512, 128
m = 128
k_tile = 256
n_k_sub = hd // k_tile # 2
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
# FP32 reference (full attention)
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref_unnorm = attn_exp @ v.float()
ref_norm = (attn_exp / attn_sum) @ v.float()
ref_lse = (torch.log(attn_sum.squeeze(-1)) + attn_max.squeeze(-1))[0].item()
# Use the hd=256 kernel (no k_sub path) with k_tile=256
# Call once per k_sub tile, merge results via online softmax
kernel = FmhaKernel(head_dim=k_tile, s_k=n, normalize=False)
pv_n_tile = kernel.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
print(f'hd={hd}, k_tile={k_tile}, n_k_sub={n_k_sub}, pv_n_tile={pv_n_tile}', flush=True)
print(f'Compiling k_tile={k_tile} kernel...', flush=True)
# Compile once with the first k_sub tile
q0 = q[:, 0:k_tile, :].contiguous()
k0 = k[:, 0:k_tile, :].contiguous()
v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
mQ0 = ct.from_dlpack(q0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q0))
mK0 = ct.from_dlpack(k0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k0))
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
t0 = time.time()
compiled = cute.compile(kernel, mQ0, mK0, mV, mC, stream, mLSE)
t1 = time.time()
print(f'Compilation took {t1-t0:.1f}s', flush=True)
# Run each k_sub tile and accumulate via online softmax merge
# LSE_i = ln(sum(exp(S_i - m_i))) + m_i (in natural log domain)
# Merge: O = (exp(LSE_0 - LSE_max) * O_0 + exp(LSE_1 - LSE_max) * O_1) /
# (exp(LSE_0 - LSE_max) + exp(LSE_1 - LSE_max))
# where LSE_max = max(LSE_0, LSE_1)
# Collect (un-norm O, LSE) for each k_sub and each pv_tile
all_o_unnorm = [] # list of (n_k_sub, hd) tensors
all_lse = [] # list of (n_k_sub,) LSE values
for ks in range(n_k_sub):
ks_start = ks * k_tile
ks_end = ks_start + k_tile
q_ks = q[:, ks_start:ks_end, :].contiguous()
k_ks = k[:, ks_start:ks_end, :].contiguous()
o_ks = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
lse_ks = None
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v[:, v_start:v_end].contiguous().unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor.zero_()
mQ = ct.from_dlpack(q_ks).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_ks))
mK = ct.from_dlpack(k_ks).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_ks))
mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
compiled(mQ, mK, mV, mC, stream, mLSE)
torch.cuda.synchronize()
o_ks[:, v_start:v_end, :] = c_tile
if nt == 0:
lse_ks = lse_tensor[0, 0, 0].item()
all_o_unnorm.append(o_ks[:, :, 0].float())
all_lse.append(lse_ks)
print(f' k_sub={ks}: lse={lse_ks:.4f}', flush=True)
# Online softmax merge
# O_unnorm_full = sum_ks exp(lse_ks - lse_max) * O_ks
# Normalization: O_norm = O_unnorm_full / sum_ks exp(lse_ks - lse_max)
lse_max = max(all_lse)
o_merged_unnorm = torch.zeros(m, hd, dtype=torch.float32, device='cuda')
denom = 0.0
for ks in range(n_k_sub):
w = math.exp(all_lse[ks] - lse_max)
o_merged_unnorm += w * all_o_unnorm[ks]
denom += w
o_merged_norm = o_merged_unnorm / denom
cos_unnorm = torch.nn.functional.cosine_similarity(
o_merged_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
).item()
cos_norm = torch.nn.functional.cosine_similarity(
o_merged_norm.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)
).item()
status = "PASS" if cos_norm >= 0.99 else "FAIL"
print(f'\nhd=512 (external k_sub merge): cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f} {status}')
if cos_norm < 0.99:
print(f' o_merged[0,:4]={o_merged_norm[0,:4].tolist()}')
print(f' ref[0,:4]={ref_norm[0,:4].tolist()}')
if __name__ == '__main__':
test()