[ROCm][V1][Bugfix] Add get_builder_cls method to the ROCmAttentionBackend class (#14065)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -49,6 +50,10 @@ class ROCmAttentionBackend(AttentionBackend):
|
||||
def use_cascade_attention(*args, **kwargs) -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
    """Return the attention-metadata builder class for this backend.

    The class returned is ``FlashAttentionMetadataBuilder``, imported from
    ``vllm.v1.attention.backends.flash_attn``; the class object itself is
    returned, not an instance.
    """
    builder_cls = FlashAttentionMetadataBuilder
    return builder_cls
|
||||
|
||||
|
||||
class ROCmAttentionImpl(AttentionImpl):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user