D1.4: Add external k_sub merge test for hd=512 (avoids slow in-kernel k_sub compilation)
This commit is contained in:
137
tests/unit/test_d1_hd512_merge.py
Normal file
137
tests/unit/test_d1_hd512_merge.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""D1.4 hd=512 test using external k_sub merge.
|
||||
|
||||
Instead of the k_sub path in the kernel (which causes 45+ min compilation),
|
||||
we call the kernel once per k_sub tile with Q and K pre-sliced.
|
||||
The online softmax merge (same as D5) combines the partial results.
|
||||
|
||||
The kernel always runs at k_tile=256 (same as hd=256, proven to compile fast).
|
||||
"""
|
||||
import torch, math, time
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def merge_k_sub_partials(partial_outputs, partial_lses):
    """Combine per-k_sub partial outputs with row-wise log-sum-exp weights.

    For every row ``r`` the weight of partial ``i`` is
    ``exp(lse_i[r] - max_j lse_j[r])``.  The weighted sum of the partial
    outputs is the "un-normalized" merge; dividing it by the sum of the
    weights gives the "normalized" merge.

    Args:
        partial_outputs: sequence of ``(rows, cols)`` tensors, one per k_sub.
        partial_lses: sequence of ``(rows,)`` tensors, one per k_sub, holding
            the LSE value associated with each row of the matching output.

    Returns:
        Tuple ``(merged_unnorm, merged_norm)`` of float32 ``(rows, cols)``
        tensors on the same device as the inputs.
    """
    lse = torch.stack([x.float() for x in partial_lses])          # (n_k_sub, rows)
    outputs = torch.stack([o.float() for o in partial_outputs])   # (n_k_sub, rows, cols)
    # Subtract the per-row maximum before exponentiating so weights stay in (0, 1].
    lse_max = lse.max(dim=0, keepdim=True).values
    weights = torch.exp(lse - lse_max).unsqueeze(-1)              # (n_k_sub, rows, 1)
    merged_unnorm = (weights * outputs).sum(dim=0)
    merged_norm = merged_unnorm / weights.sum(dim=0)
    return merged_unnorm, merged_norm


def test():
    """Check hd=512 attention assembled from hd=256 kernel launches.

    The kernel is compiled once for ``head_dim=k_tile`` and launched once per
    (k_sub, PV tile) pair with Q and K pre-sliced along the head dimension.
    The partial outputs are merged with ``merge_k_sub_partials`` and compared
    against an FP32 PyTorch reference via cosine similarity; the test fails
    when the normalized similarity is below 0.99.

    Returns early (after printing a message) when no CUDA device is present.

    NOTE(review): Q and K are sliced along the head (contraction) dimension,
    so each launch sees only a partial dot product, while the reference takes
    the softmax of the *sum* of those partial scores.  An LSE-weighted merge
    is the standard recombination for splits along the key axis; confirm it is
    the intended recombination for this head-dimension split.
    NOTE(review): V has ``hd`` columns but the kernel is built with
    ``head_dim=k_tile``; if ``n_pv_tiles * pv_n_tile < hd`` the trailing output
    columns stay zero.  The kernel may also apply a softmax scale derived from
    ``k_tile`` rather than ``hd`` - confirm both against FmhaKernel.
    """
    if not torch.cuda.is_available():
        print('CUDA device not available; skipping hd=512 external k_sub merge test', flush=True)
        return

    torch.manual_seed(42)
    hd, n = 512, 128
    m = 128
    k_tile = 256
    n_k_sub = hd // k_tile  # 2

    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')

    # FP32 reference (full attention); the score matrix is computed once and reused.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    vf = v.float()
    scale = 1.0 / math.sqrt(hd)
    scores = qf @ kf.T * scale
    attn_max = scores.max(dim=-1, keepdim=True)[0]
    attn_exp = torch.exp(scores - attn_max)
    attn_sum = attn_exp.sum(dim=-1, keepdim=True)
    ref_unnorm = attn_exp @ vf
    ref_norm = (attn_exp / attn_sum) @ vf

    # Use the hd=256 kernel (no k_sub path) with k_tile=256.
    # It is called once per k_sub tile and the results are merged afterwards.
    kernel = FmhaKernel(head_dim=k_tile, s_k=n, normalize=False)
    pv_n_tile = kernel.pv_n_tile
    n_pv_tiles = kernel.n_pv_tiles
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    print(f'hd={hd}, k_tile={k_tile}, n_k_sub={n_k_sub}, pv_n_tile={pv_n_tile}', flush=True)
    print(f'Compiling k_tile={k_tile} kernel...', flush=True)

    # Compile once, using the first k_sub tile as example arguments.
    q0 = q[:, 0:k_tile, :].contiguous()
    k0 = k[:, 0:k_tile, :].contiguous()
    v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')

    mQ0 = ct.from_dlpack(q0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q0))
    mK0 = ct.from_dlpack(k0).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k0))
    mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
    mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
    mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))

    t0 = time.time()
    compiled = cute.compile(kernel, mQ0, mK0, mV, mC, stream, mLSE)
    t1 = time.time()
    print(f'Compilation took {t1-t0:.1f}s', flush=True)

    # Run each k_sub tile and collect (un-normalized O, per-row LSE).
    #   LSE_i = ln(sum(exp(S_i - m_i))) + m_i   (natural-log domain)
    #   Merge: O = sum_i exp(LSE_i - LSE_max) * O_i / sum_i exp(LSE_i - LSE_max)
    #   where LSE_max is the per-row maximum over the k_sub tiles.
    all_o_unnorm = []  # one (m, hd) float32 tensor per k_sub
    all_lse = []       # one (m,) float32 tensor per k_sub

    for ks in range(n_k_sub):
        ks_start = ks * k_tile
        ks_end = ks_start + k_tile
        q_ks = q[:, ks_start:ks_end, :].contiguous()
        k_ks = k[:, ks_start:ks_end, :].contiguous()
        # The Q/K wrappers depend only on ks, so build them once per k_sub.
        mQ = ct.from_dlpack(q_ks).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q_ks))
        mK = ct.from_dlpack(k_ks).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k_ks))

        o_ks = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
        lse_ks = None

        for nt in range(n_pv_tiles):
            v_start = nt * pv_n_tile
            v_end = v_start + pv_n_tile
            v_tile = v[:, v_start:v_end].contiguous().unsqueeze(-1)
            c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
            lse_tensor.zero_()

            mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile))
            mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))

            compiled(mQ, mK, mV, mC, stream, mLSE)
            torch.cuda.synchronize()

            o_ks[:, v_start:v_end, :] = c_tile
            if nt == 0:
                # Keep the LSE of every row; clone because the buffer is
                # zeroed and overwritten by the next launch.
                lse_ks = lse_tensor[:, 0, 0].clone()

        all_o_unnorm.append(o_ks[:, :, 0].float())
        all_lse.append(lse_ks)
        print(f' k_sub={ks}: lse={lse_ks[0].item():.4f}', flush=True)

    # Row-wise merge of the partial results.
    o_merged_unnorm, o_merged_norm = merge_k_sub_partials(all_o_unnorm, all_lse)

    cos_unnorm = torch.nn.functional.cosine_similarity(
        o_merged_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
    ).item()
    cos_norm = torch.nn.functional.cosine_similarity(
        o_merged_norm.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)
    ).item()

    # `>=` is False for NaN, so a NaN similarity is reported as a failure too.
    passed = cos_norm >= 0.99
    status = "PASS" if passed else "FAIL"
    print(f'\nhd=512 (external k_sub merge): cos_unnorm {cos_unnorm:.6f} cos_norm {cos_norm:.6f} {status}')
    if not passed:
        print(f' o_merged[0,:4]={o_merged_norm[0,:4].tolist()}')
        print(f' ref[0,:4]={ref_norm[0,:4].tolist()}')

    # Fail the test instead of only printing the status.
    assert passed, f'hd=512 external k_sub merge: cos_norm {cos_norm:.6f} < 0.99'
|
||||
|
||||
|
||||
# Script entry point: lets the file be run directly, outside of a test runner.
if __name__ == "__main__":
    test()
|
||||
Reference in New Issue
Block a user