[Attention][V1] Toggle for v1 attention backend (#18275)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
commit da4b69d0b4 (parent c9479b2920), committed via GitHub
@@ -264,8 +264,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
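The relaxed assertions let an FP8 KV cache arrive either as raw uint8 storage or already typed as the platform's FP8 dtype. A minimal sketch of the zero-copy reinterpretation the comment above describes, using torch.float8_e4m3fn as a stand-in for current_platform.fp8_dtype() (requires torch >= 2.1):

    import torch

    fp8_dtype = torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()

    # FP8 values held as raw bytes, the pre-change invariant
    raw_cache = torch.zeros(16, dtype=torch.uint8)
    # The same memory reinterpreted as FP8 for Triton; no copy is made
    typed_cache = raw_cache.view(fp8_dtype)

    assert raw_cache.dtype == torch.uint8
    assert typed_cache.dtype == fp8_dtype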
@@ -744,8 +744,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
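Both hunks defer to current_platform.fp8_dtype() for the typed format they accept. A hedged sketch of what such a helper chooses between; the ROCm/CUDA split below is an assumption based on the two torch FP8 variants, not the exact vLLM implementation:

    import torch

    def fp8_dtype(is_rocm: bool) -> torch.dtype:
        # NVIDIA GPUs use the OCP float8_e4m3fn encoding; older ROCm
        # hardware uses the fnuz variant with a different NaN convention.
        return torch.float8_e4m3fnuz if is_rocm else torch.float8_e4m3fn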
vllm/envs.py (12 lines changed)
@@ -15,6 +15,7 @@ if TYPE_CHECKING:
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -290,6 +291,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Force vllm to use a specific flash-attention version (2 or 3), only valid
     # when using the flash-attention backend.
     "VLLM_FLASH_ATTN_VERSION":
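The new entry follows the pattern vllm/envs.py uses throughout: the TYPE_CHECKING declaration above gives type checkers the attribute's type, while at runtime each key maps to a zero-argument parser evaluated on access. A self-contained sketch of that mechanism (the module-level __getattr__ is an assumption about how envs.py resolves attributes, in the spirit of PEP 562):

    import os
    from typing import TYPE_CHECKING, Any, Callable

    if TYPE_CHECKING:
        # Seen only by type checkers; never executed at runtime.
        VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False

    environment_variables: dict[str, Callable[[], Any]] = {
        # Parsed lazily: the environment is read when the entry is called.
        "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
        lambda:
        (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
         ("true", "1")),
    }

    def __getattr__(name: str):
        # Module-level attribute hook: envs.VLLM_... triggers the parser.
        return environment_variables[name]()

With this in place, setting VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 (or "true", case-insensitively) enables the toggle; anything else, including leaving it unset, yields False.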
@@ -323,8 +331,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
 
     # Whether to log responses from API Server for debugging
     "VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
-    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
-    lower() == "true",
+    lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
+                           ).lower() == "true",
 
     # S3 access information, used for tensorizer to load model from S3
     "S3_ACCESS_KEY_ID":
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Optional
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -126,6 +127,8 @@ class TritonAttentionImpl(AttentionImpl):
                                       "TritonAttentionImpl")
 
         self.fp8_dtype = current_platform.fp8_dtype()
+        self.force_prefill_decode_attn = \
+            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
 
     def forward(
         self,
@@ -166,9 +169,9 @@ class TritonAttentionImpl(AttentionImpl):
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
+        use_prefill_decode_attn = (self.force_prefill_decode_attn
+                                   or not num_q_is_pow2)
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         if use_prefill_decode_attn:
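The dispatch condition now folds the environment toggle into the existing heuristic: the unified Triton kernel runs only when the query-to-KV head ratio is a power of two and the toggle is off. A small standalone sketch of that decision logic with the bit trick spelled out (the helper function is illustrative, not vLLM code):

    def use_prefill_decode_attn(num_queries_per_kv: int,
                                force_prefill_decode_attn: bool) -> bool:
        # n & (n - 1) clears the lowest set bit, so the result is zero
        # exactly when n has a single set bit, i.e. n is a power of two.
        num_q_is_pow2 = (num_queries_per_kv & (num_queries_per_kv - 1)) == 0
        return force_prefill_decode_attn or not num_q_is_pow2

    # GQA ratio 8 (a power of two): unified kernel unless the toggle
    # forces the split prefill/decode path.
    assert use_prefill_decode_attn(8, False) is False
    assert use_prefill_decode_attn(8, True) is True
    # Ratio 6 (not a power of two): always the split path.
    assert use_prefill_decode_attn(6, False) is True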