[V1] Enable Triton(ROCm) Attention backend for Nvidia GPUs (#14071)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-"""Attention layer with PagedAttention on rocm"""
+"""Attention layer with PagedAttention and Triton prefix prefill."""
 
 from typing import Any, Optional
 
 import torch
@@ -16,7 +16,7 @@ from vllm.v1.attention.backends.flash_attn import (
 logger = init_logger(__name__)
 
 
-class ROCmAttentionBackend(AttentionBackend):
+class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
@@ -26,11 +26,11 @@ class ROCmAttentionBackend(AttentionBackend):
     @staticmethod
     def get_name() -> str:
-        return "ROCM_ATTN_VLLM_V1"
+        return "TRITON_ATTN_VLLM_V1"
 
     @staticmethod
-    def get_impl_cls() -> type["ROCmAttentionImpl"]:
-        return ROCmAttentionImpl
+    def get_impl_cls() -> type["TritonAttentionImpl"]:
+        return TritonAttentionImpl
 
     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -56,7 +56,7 @@ class ROCmAttentionBackend(AttentionBackend):
         return FlashAttentionMetadataBuilder
 
 
-class ROCmAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(AttentionImpl):
 
     def __init__(
         self,
@@ -73,7 +73,7 @@ class ROCmAttentionImpl(AttentionImpl):
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
-                "ROCmAttention does not support block-sparse attention.")
+                "TritonAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -90,17 +90,17 @@ class ROCmAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
+        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
-                f"Head size {head_size} is not supported by ROCmAttention. "
+                f"Head size {head_size} is not supported by TritonAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
 
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
-                                      "ROCmAttentionImpl")
+                                      "TritonAttentionImpl")
 
     def forward(
         self,
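After this rename the backend registers under the name "TRITON_ATTN_VLLM_V1" (see get_name above). As a rough usage sketch that is not part of this commit: assuming vLLM's VLLM_ATTENTION_BACKEND environment variable can be used to force a specific backend by that name, opting into the Triton backend on an NVIDIA GPU could look like the following (model and prompt are placeholders chosen for illustration).

# Hedged sketch, not from this commit: select the renamed backend by name
# via the VLLM_ATTENTION_BACKEND override before the engine is constructed.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"  # name returned by get_name() above

from vllm import LLM, SamplingParams

# Placeholder model and prompt; any supported model would do here.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)

If the environment variable is left unset, vLLM falls back to its normal backend selection logic, so the override above is only needed to exercise this backend explicitly.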