auto: pre-test commit

This commit is contained in:
2026-05-28 19:08:04 +00:00
parent a723b524f7
commit 41343fdc6b
2 changed files with 91 additions and 6 deletions

View File

@@ -92,12 +92,10 @@ class FmhaKernel:
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256)
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
if head_dim > 256:
self.pv_n_tile = 128
# PV N=16 sub-tiles: avoid tcgen05.mma Layout D bug where N=64
# skips TMEM columns 32-35 and 48-51. N=16 works for all HD values.
# More PV calls per K-tile, but each is correct.
self.pv_n_tile = 16
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.num_query_heads = num_query_heads

View File

@@ -0,0 +1,87 @@
"""Test FMHA with pv_n_tile=16 (N=16 sub-tiles for PV GEMM).
This tests the CuTeDSL FMHA kernel with the Layout D bug fix:
- pv_n_tile=16 avoids the tcgen05.mma N=64 bug (missing TMEM columns)
- Should work for HD=16, 64, 128, 256 with cosine ≥ 0.999
"""
import torch
import math
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
from dsv4.kernels.attention.fmha import FmhaKernel
from dsv4.kernels.attention.production import dsv4_attention_per_head
def test_fmha_pv16_hd64():
"""HD=64 with pv_n_tile=16 (4 PV sub-tiles)"""
hd = 64
sk = 128
n_h = 1
scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
q = torch.randn(1, sk, hd, dtype=torch.bfloat16, device='cuda') # head-packed
k = torch.randn(sk, hd, dtype=torch.bfloat16, device='cuda')
v = torch.randn(hd, sk, dtype=torch.bfloat16, device='cuda')
o = torch.zeros(1, sk, hd, dtype=torch.bfloat16, device='cuda')
# FMHA kernel
dsv4_attention_per_head(q, k, v, o, sk, scale, swa_len=sk, is_causal=False)
# Reference
q_ref = q[0].float() # (sk, hd)
k_ref = k.float() # (sk, hd)
v_ref = v.float().T # (hd, sk) → need (sk, hd) for matmul
s = q_ref @ k_ref.T * scale # (sk, sk)
s_max = s.max(dim=-1, keepdim=True).values
p = torch.softmax(s - s_max, dim=-1)
o_ref = (p @ v_ref.T).to(torch.bfloat16) # (sk, hd) → bf16
# Compare row 0
o_row0 = o[0, 0].float()
o_ref0 = o_ref[0].float()
cs = torch.nn.functional.cosine_similarity(o_row0.unsqueeze(0), o_ref0.unsqueeze(0)).item()
print(f"HD={hd} pv_n_tile=16: cosine={cs:.8f}")
assert cs > 0.999, f"Cosine {cs} < 0.999"
print("PASSED")
def test_fmha_pv16_hd128():
"""HD=128 with pv_n_tile=16 (8 PV sub-tiles)"""
hd = 128
sk = 128
n_h = 1
scale = 1.0 / math.sqrt(hd)
torch.manual_seed(42)
q = torch.randn(1, sk, hd, dtype=torch.bfloat16, device='cuda')
k = torch.randn(sk, hd, dtype=torch.bfloat16, device='cuda')
v = torch.randn(hd, sk, dtype=torch.bfloat16, device='cuda')
o = torch.zeros(1, sk, hd, dtype=torch.bfloat16, device='cuda')
dsv4_attention_per_head(q, k, v, o, sk, scale, swa_len=sk, is_causal=False)
# Reference
q_ref = q[0].float()
k_ref = k.float()
v_ref = v.float()
s = q_ref @ k_ref.T * scale
p = torch.softmax(s - s.max(dim=-1, keepdim=True).values, dim=-1)
o_ref = (p @ v_ref.T).to(torch.bfloat16)
o_row0 = o[0, 0].float()
o_ref0 = o_ref[0].float()
cs = torch.nn.functional.cosine_similarity(o_row0.unsqueeze(0), o_ref0.unsqueeze(0)).item()
print(f"HD={hd} pv_n_tile=16: cosine={cs:.8f}")
assert cs > 0.999, f"Cosine {cs} < 0.999"
print("PASSED")
if __name__ == '__main__':
test_fmha_pv16_hd64()
test_fmha_pv16_hd128()