"""D1: Test hd=64 only with CUDA_LAUNCH_BLOCKING for crash debug.

Compiles and launches ``FmhaKernel`` once for head_dim=64 on random bfloat16
inputs, then prints the cosine similarity between the kernel output and a
float32 softmax-attention reference computed with plain PyTorch.
"""
import os

# Exported before torch is imported (same ordering as the original script) so
# the flag is already in the environment when CUDA gets initialised.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel


def _as_dynamic_cute_tensor(t):
    """Wrap torch tensor *t* via DLPack and mark its layout dynamic on its leading dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


# Problem size: head dimension, number of key/value rows, number of query rows.
hd = 64
n = 128
m = 128

torch.manual_seed(42)

# Inputs are drawn in the order q, k, v so the seeded values stay reproducible.
# q, k and the output buffer c carry a trailing axis of size 1; v is created
# 2-D and only gets that axis for the kernel call.
# NOTE(review): the trailing axis is presumably a batch/head axis expected by
# FmhaKernel -- confirm against the kernel's layout contract.
bf16_on_gpu = dict(dtype=torch.bfloat16, device='cuda')
q = torch.randn(m, hd, 1, **bf16_on_gpu)
k = torch.randn(n, hd, 1, **bf16_on_gpu)
v = torch.randn(n, hd, **bf16_on_gpu)
c = torch.zeros(m, hd, 1, **bf16_on_gpu)

# Float32 reference: softmax(Q K^T / sqrt(hd)) V.
q_f32 = q[:, :, 0].float()
k_f32 = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
logits = q_f32 @ k_f32.T * scale
probs = torch.softmax(logits, dim=-1)
ref = probs @ v.float()

# Kernel-side operands.
v_kernel = v.unsqueeze(-1)
mQ = _as_dynamic_cute_tensor(q)
mK = _as_dynamic_cute_tensor(k)
mV = _as_dynamic_cute_tensor(v_kernel)
mC = _as_dynamic_cute_tensor(c)

# Launch on PyTorch's current CUDA stream.
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaKernel(head_dim=hd, s_k=n)

# Progress lines are flushed so they are visible even if the next step crashes.
print(f'Compiling hd={hd}...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)

print(f'Running hd={hd}...', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()

# Compare kernel output and reference as two flat vectors.
out = c[:, :, 0].float()
out_row = out.flatten().unsqueeze(0)
ref_row = ref.flatten().unsqueeze(0)
cos = torch.nn.functional.cosine_similarity(out_row, ref_row).item()
print(f'hd={hd}: cos {cos:.6f}')