auto: pre-test commit
This commit is contained in:
@@ -92,12 +92,10 @@ class FmhaKernel:
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.pv_n_tile = min(head_dim, 256)
|
||||
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
|
||||
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
|
||||
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
|
||||
if head_dim > 256:
|
||||
self.pv_n_tile = 128
|
||||
# PV N=16 sub-tiles: avoid tcgen05.mma Layout D bug where N=64
|
||||
# skips TMEM columns 32-35 and 48-51. N=16 works for all HD values.
|
||||
# More PV calls per K-tile, but each is correct.
|
||||
self.pv_n_tile = 16
|
||||
self.n_pv_tiles = head_dim // self.pv_n_tile
|
||||
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
|
||||
self.num_query_heads = num_query_heads
|
||||
|
||||
87
tests/unit/test_fmha_pv16.py
Normal file
87
tests/unit/test_fmha_pv16.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Test FMHA with pv_n_tile=16 (N=16 sub-tiles for the PV GEMM).

This tests the CuTeDSL FMHA kernel with the Layout D bug workaround:
- pv_n_tile=16 avoids the tcgen05.mma N=64 bug (missing TMEM columns)
- Should work for HD=16, 64, 128, 256 with cosine ≥ 0.999
  (the tests in this file exercise HD=64 and HD=128)
"""
|
||||
import torch
|
||||
import math
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
from dsv4.kernels.attention.production import dsv4_attention_per_head
|
||||
|
||||
|
||||
def test_fmha_pv16_hd64():
    """HD=64 with pv_n_tile=16 (4 PV sub-tiles).

    Runs ``dsv4_attention_per_head`` on random bf16 inputs (non-causal,
    ``swa_len == sk``) and compares query row 0 of the output buffer with a
    float32 PyTorch reference using cosine similarity.

    Requires a CUDA device; all tensors are allocated on ``'cuda'``.

    Raises:
        AssertionError: if the cosine similarity of row 0 is not above 0.999.
    """
    hd = 64
    sk = 128
    scale = 1.0 / math.sqrt(hd)

    torch.manual_seed(42)
    q = torch.randn(1, sk, hd, dtype=torch.bfloat16, device='cuda')  # head-packed
    k = torch.randn(sk, hd, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(hd, sk, dtype=torch.bfloat16, device='cuda')  # allocated as (hd, sk)
    o = torch.zeros(1, sk, hd, dtype=torch.bfloat16, device='cuda')

    # FMHA kernel (presumably writes its result into ``o`` in place, since
    # ``o`` is what gets compared below).
    dsv4_attention_per_head(q, k, v, o, sk, scale, swa_len=sk, is_causal=False)

    # Reference, computed in float32.
    q_ref = q[0].float()  # (sk, hd)
    k_ref = k.float()     # (sk, hd)
    v_ref = v.float().T   # (hd, sk) -> (sk, hd), ready for P @ V

    s = q_ref @ k_ref.T * scale  # (sk, sk)
    s_max = s.max(dim=-1, keepdim=True).values
    p = torch.softmax(s - s_max, dim=-1)
    # v_ref is already (sk, hd). Transposing it a second time would yield
    # (hd, sk), which cannot be multiplied by the (sk, sk) matrix p when
    # hd != sk (here 64 vs 128).
    o_ref = (p @ v_ref).to(torch.bfloat16)  # (sk, hd) -> bf16

    # Compare row 0
    o_row0 = o[0, 0].float()
    o_ref0 = o_ref[0].float()

    cs = torch.nn.functional.cosine_similarity(o_row0.unsqueeze(0), o_ref0.unsqueeze(0)).item()
    print(f"HD={hd} pv_n_tile=16: cosine={cs:.8f}")
    assert cs > 0.999, f"Cosine {cs} < 0.999"
    print("PASSED")
|
||||
|
||||
|
||||
def test_fmha_pv16_hd128():
    """HD=128 with pv_n_tile=16 (8 PV sub-tiles).

    Feeds random bf16 tensors to ``dsv4_attention_per_head`` (non-causal,
    ``swa_len == sk``) and checks query row 0 of the output buffer against a
    float32 PyTorch softmax-attention reference via cosine similarity.

    Requires a CUDA device; all tensors are allocated on ``'cuda'``.

    Raises:
        AssertionError: if the cosine similarity of row 0 is not above 0.999.
    """
    hd = 128
    sk = 128
    scale = 1.0 / math.sqrt(hd)
    tensor_opts = dict(dtype=torch.bfloat16, device='cuda')

    # Draw order (q, k, v) is kept so the seeded values are unchanged.
    torch.manual_seed(42)
    q = torch.randn(1, sk, hd, **tensor_opts)
    k = torch.randn(sk, hd, **tensor_opts)
    v = torch.randn(hd, sk, **tensor_opts)  # allocated as (hd, sk)
    o = torch.zeros(1, sk, hd, **tensor_opts)

    dsv4_attention_per_head(q, k, v, o, sk, scale, swa_len=sk, is_causal=False)

    # Reference in float32: softmax(Q K^T * scale) @ V, with V = v^T.
    queries = q[0].float()
    keys = k.float()
    values_t = v.float()
    scores = queries @ keys.T * scale
    row_max = scores.max(dim=-1, keepdim=True).values
    probs = torch.softmax(scores - row_max, dim=-1)
    expected = (probs @ values_t.T).to(torch.bfloat16)

    # Only the first query row is compared.
    got_row = o[0, 0].float().unsqueeze(0)
    want_row = expected[0].float().unsqueeze(0)
    cs = torch.nn.functional.cosine_similarity(got_row, want_row).item()

    print(f"HD={hd} pv_n_tile=16: cosine={cs:.8f}")
    assert cs > 0.999, f"Cosine {cs} < 0.999"
    print("PASSED")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this file directly, without pytest: execute each test in
    # definition order; the first failing assertion aborts the run.
    for _test_case in (test_fmha_pv16_hd64, test_fmha_pv16_hd128):
        _test_case()
|
||||
Reference in New Issue
Block a user