[Attention][V1] Toggle for v1 attention backend (#18275)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent c9479b2920
commit da4b69d0b4
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -126,6 +127,8 @@ class TritonAttentionImpl(AttentionImpl):
                                       "TritonAttentionImpl")
 
         self.fp8_dtype = current_platform.fp8_dtype()
+        self.force_prefill_decode_attn = \
+            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
 
     def forward(
         self,
@@ -166,9 +169,9 @@ class TritonAttentionImpl(AttentionImpl):
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
-
+        num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
+        use_prefill_decode_attn = (self.force_prefill_decode_attn
+                                   or not num_q_is_pow2)
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         if use_prefill_decode_attn:
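The new dispatch predicate combines the environment toggle with a power-of-two test on the queries-per-KV-head ratio: the prefill/decode path is taken when the toggle forces it or when that ratio is not a power of two. A minimal sketch of the logic follows; the function names are local to the example, not vLLM APIs.

# Sketch of the dispatch predicate from the hunk above.
def is_pow2(n: int) -> bool:
    # Classic bit trick: a power of two has exactly one bit set, so
    # n & (n - 1) clears it and yields zero.
    return n > 0 and (n & (n - 1)) == 0


def should_use_prefill_decode_attn(num_queries_per_kv: int,
                                   force_prefill_decode_attn: bool) -> bool:
    return force_prefill_decode_attn or not is_pow2(num_queries_per_kv)


if __name__ == "__main__":
    # GQA with 32 query heads and 8 KV heads -> ratio 4, a power of two.
    print(should_use_prefill_decode_attn(32 // 8, False))  # False
    # Same ratio, but the env toggle forces the prefill/decode kernel.
    print(should_use_prefill_decode_attn(32 // 8, True))   # True
    # Ratio 6 is not a power of two, so the fallback path is chosen anyway.
    print(should_use_prefill_decode_attn(6, False))        # True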