[Kernel] Support DCP for Triton backend (#25132)

Signed-off-by: Wei Wei <wwei6@meta.com>
This commit is contained in:
Wei Wei
2025-09-24 18:09:34 -07:00
committed by GitHub
parent 52d0cb8458
commit 05c19485a5
4 changed files with 30 additions and 8 deletions

View File

@@ -46,6 +46,8 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
@@ -60,6 +62,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
k_buffer,
v_buffer,
o,
lse,
req_to_token,
b_seq_len,
attn_logits,
@@ -72,12 +75,14 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
lse1 = torch.zeros_like(lse)
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
lse1,
req_to_page,
b_seq_len,
attn_logits,