[Kernel] [V1] Improved performance for V1 Triton (ROCm) backend (#14152)

This commit is contained in:
Thomas Parnell
2025-03-06 16:39:16 +01:00
committed by GitHub
parent 4f27044aab
commit 6bd1dd9d26
4 changed files with 398 additions and 77 deletions

View File

@@ -6,8 +6,9 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
@@ -156,20 +157,22 @@ class ROCmAttentionImpl(AttentionImpl):
)
# Compute attention and update output up to `num_actual_tokens`.
context_attention_fwd(q=query[:num_actual_tokens],
k=key[:num_actual_tokens],
v=value[:num_actual_tokens],
o=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
k_cache=key_cache,
v_cache=value_cache,
b_loc=attn_metadata.block_table,
b_start_loc=attn_metadata.query_start_loc,
b_seq_len=attn_metadata.seq_lens,
max_input_len=attn_metadata.max_query_len,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=attn_metadata.block_table,
query_start_loc=attn_metadata.query_start_loc,
seq_lens=attn_metadata.seq_lens,
max_query_len=attn_metadata.max_query_len,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
return output