[Kernel] [V1] Improved performance for V1 Triton (ROCm) backend (#14152)
This commit is contained in:
@@ -6,8 +6,9 @@ import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
|
||||
@@ -156,20 +157,22 @@ class ROCmAttentionImpl(AttentionImpl):
|
||||
)
|
||||
|
||||
# Compute attention and update output up to `num_actual_tokens`.
|
||||
context_attention_fwd(q=query[:num_actual_tokens],
|
||||
k=key[:num_actual_tokens],
|
||||
v=value[:num_actual_tokens],
|
||||
o=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
b_loc=attn_metadata.block_table,
|
||||
b_start_loc=attn_metadata.query_start_loc,
|
||||
b_seq_len=attn_metadata.seq_lens,
|
||||
max_input_len=attn_metadata.max_query_len,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
chunked_prefill_paged_decode(
|
||||
query=query[:num_actual_tokens],
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
output=output[:num_actual_tokens],
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
block_table=attn_metadata.block_table,
|
||||
query_start_loc=attn_metadata.query_start_loc,
|
||||
seq_lens=attn_metadata.seq_lens,
|
||||
max_query_len=attn_metadata.max_query_len,
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
sliding_window=self.sliding_window[0],
|
||||
sm_scale=self.scale)
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user