[Kernel] [V1] Fix performance regression for triton unified attention (#18161)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
         stride_k_cache_0: tl.int64,  # int
         stride_k_cache_1: tl.int64,  # int
         stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_3: tl.constexpr,  # int
         stride_v_cache_0: tl.int64,  # int
         stride_v_cache_1: tl.int64,  # int
         stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.int64,  # int
+        stride_v_cache_3: tl.constexpr,  # int
         query_start_len_ptr,  # [num_seqs+1]
         BLOCK_Q: tl.constexpr,  # int
         num_seqs: tl.int32,
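The kernel-side change above only retypes the innermost K and V cache strides from tl.int64 to tl.constexpr. In Triton, a tl.constexpr argument is baked into the compiled kernel, so the compiler can specialize the address arithmetic around it (the innermost stride of a contiguous cache is typically 1) instead of carrying a runtime 64-bit multiply through the attention loop; that specialization is presumably what recovers the lost performance. Below is a minimal, self-contained sketch of the same pattern on a toy copy kernel; copy_kernel, stride_row, stride_col and the launch are illustrative, not part of the commit.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(
        src_ptr,
        dst_ptr,
        stride_row: tl.int64,  # runtime value: one compiled kernel serves any stride
        stride_col: tl.constexpr,  # compile-time constant: folded into addressing
        BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    vals = tl.load(src_ptr + row * stride_row + offs * stride_col)
    tl.store(dst_ptr + row * stride_row + offs * stride_col, vals)


x = torch.randn(4, 128, device="cuda")
y = torch.empty_like(x)
# stride_col == 1 here, so the constexpr variant compiles the multiply away.
copy_kernel[(4, )](x, y, x.stride(0), x.stride(1), BLOCK=128)
torch.testing.assert_close(x, y)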
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with PagedAttention and Triton prefix prefill."""
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -12,10 +12,23 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
+
+if TYPE_CHECKING:
+    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
 
+class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
+
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, block_table)
+        self.aot_schedule = False
+
+
 class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
@@ -52,8 +65,8 @@ class TritonAttentionBackend(AttentionBackend):
         return False
 
     @staticmethod
-    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
-        return FlashAttentionMetadataBuilder
+    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
+        return TritonAttentionMetadataBuilder
 
 
 class TritonAttentionImpl(AttentionImpl):
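The remaining hunks introduce TritonAttentionMetadataBuilder, a thin subclass of the FlashAttention builder that forces self.aot_schedule = False, and point get_builder_cls() at it so the Triton backend no longer inherits the FlashAttention ahead-of-time scheduling behaviour. A rough sketch of how such a builder class gets resolved and constructed follows; the helper and its arguments are hypothetical, not vLLM's actual model-runner code.

def make_attn_metadata_builder(backend_cls, runner, kv_cache_spec, block_table):
    # Hypothetical helper (not vLLM code): ask the backend which builder class it
    # advertises, then instantiate it with the signature added in this commit.
    builder_cls = backend_cls.get_builder_cls()
    return builder_cls(runner, kv_cache_spec, block_table)

# For TritonAttentionBackend this now yields a TritonAttentionMetadataBuilder,
# whose __init__ sets self.aot_schedule = False, whereas before this commit the
# FlashAttention builder (and its AOT scheduling) was used instead.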