diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py
index e13b8cfd..23f071d6 100644
--- a/dsv4/kernels/attention/fmha.py
+++ b/dsv4/kernels/attention/fmha.py
@@ -92,12 +92,10 @@ class FmhaKernel:
         self.head_dim = head_dim
         self.s_k = s_k
         self.n_kv_tiles = s_k // 128
-        self.pv_n_tile = min(head_dim, 256)
-        # At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
-        # making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
-        # (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
-        if head_dim > 256:
-            self.pv_n_tile = 128
+        # PV N=16 sub-tiles: avoid tcgen05.mma Layout D bug where N=64
+        # skips TMEM columns 32-35 and 48-51. N=16 works for all HD values.
+        # More PV calls per K-tile, but each is correct.
+        self.pv_n_tile = 16
         self.n_pv_tiles = head_dim // self.pv_n_tile
         self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
         self.num_query_heads = num_query_heads
diff --git a/tests/unit/test_fmha_pv16.py b/tests/unit/test_fmha_pv16.py
new file mode 100644
index 00000000..f0e24d57
--- /dev/null
+++ b/tests/unit/test_fmha_pv16.py
@@ -0,0 +1,87 @@
+"""Test FMHA with pv_n_tile=16 (N=16 sub-tiles for PV GEMM).
+
+This tests the CuTeDSL FMHA kernel with the Layout D bug fix:
+- pv_n_tile=16 avoids the tcgen05.mma N=64 bug (missing TMEM columns)
+- Should work for HD=16, 64, 128, 256 with cosine ≥ 0.999 (only HD=64 and HD=128 are exercised here)
+"""
+import torch
+import math
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+
+from dsv4.kernels.attention.fmha import FmhaKernel
+from dsv4.kernels.attention.production import dsv4_attention_per_head
+
+
+def test_fmha_pv16_hd64():
+    """HD=64 with pv_n_tile=16 (4 PV sub-tiles)"""
+    hd = 64
+    sk = 128
+    n_h = 1
+    scale = 1.0 / math.sqrt(hd)
+
+    torch.manual_seed(42)
+    q = torch.randn(1, sk, hd, dtype=torch.bfloat16, device='cuda')  # head-packed
+    k = torch.randn(sk, hd, dtype=torch.bfloat16, device='cuda')
+    v = torch.randn(hd, sk, dtype=torch.bfloat16, device='cuda')
+    o = torch.zeros(1, sk, hd, dtype=torch.bfloat16, device='cuda')
+
+    # FMHA kernel
+    dsv4_attention_per_head(q, k, v, o, sk, scale, swa_len=sk, is_causal=False)
+
+    # Reference
+    q_ref = q[0].float()  # (sk, hd)
+    k_ref = k.float()  # (sk, hd)
+    v_ref = v.float()  # (hd, sk); transposed exactly once, in the PV matmul below
+
+    s = q_ref @ k_ref.T * scale  # (sk, sk)
+    s_max = s.max(dim=-1, keepdim=True).values
+    p = torch.softmax(s - s_max, dim=-1)
+    o_ref = (p @ v_ref.T).to(torch.bfloat16)  # (sk, sk) @ (sk, hd) → (sk, hd) bf16
+
+    # Compare row 0
+    o_row0 = o[0, 0].float()
+    o_ref0 = o_ref[0].float()
+
+    cs = torch.nn.functional.cosine_similarity(o_row0.unsqueeze(0), o_ref0.unsqueeze(0)).item()
+    print(f"HD={hd} pv_n_tile=16: cosine={cs:.8f}")
+    assert cs > 0.999, f"Cosine {cs} < 0.999"
+    print("PASSED")
+
+
+def test_fmha_pv16_hd128():
+    """HD=128 with pv_n_tile=16 (8 PV sub-tiles)"""
+    hd = 128
+    sk = 128
+    n_h = 1
+    scale = 1.0 / math.sqrt(hd)
+
+    torch.manual_seed(42)
+    q = torch.randn(1, sk, hd, dtype=torch.bfloat16, device='cuda')
+    k = torch.randn(sk, hd, dtype=torch.bfloat16, device='cuda')
+    v = torch.randn(hd, sk, dtype=torch.bfloat16, device='cuda')
+    o = torch.zeros(1, sk, hd, dtype=torch.bfloat16, device='cuda')
+
+    dsv4_attention_per_head(q, k, v, o, sk, scale, swa_len=sk, is_causal=False)
+
+    # Reference
+    q_ref = q[0].float()
+    k_ref = k.float()
+    v_ref = v.float()
+    s = q_ref @ k_ref.T * scale
+    p = torch.softmax(s - s.max(dim=-1, keepdim=True).values, dim=-1)
+    o_ref = (p @ v_ref.T).to(torch.bfloat16)
+
+    o_row0 = o[0, 0].float()
+    o_ref0 = o_ref[0].float()
+    cs = torch.nn.functional.cosine_similarity(o_row0.unsqueeze(0), o_ref0.unsqueeze(0)).item()
+    print(f"HD={hd} pv_n_tile=16: cosine={cs:.8f}")
+    assert cs > 0.999, f"Cosine {cs} < 0.999"
+    print("PASSED")
+
+
+if __name__ == '__main__':
+    test_fmha_pv16_hd64()
+    test_fmha_pv16_hd128()