[Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix (#21588)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -5,12 +5,12 @@ import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, make_dataclass
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Generic, Optional,
+                    TypeVar)

 import numpy as np
 import torch

-from vllm.attention.layer import Attention
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.utils import cdiv
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_input_batch import InputBatch

 import vllm.envs as envs
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.layer import Attention
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
 from vllm.logger import init_logger
@@ -532,6 +534,48 @@ def make_local_attention_virtual_batches(
     )


+def subclass_attention_metadata_builder(
+        name_prefix: str,
+        builder_cls: type[AttentionMetadataBuilder[M]],
+        build_preprocess_fn: Callable[[CommonAttentionMetadata],
+                                      CommonAttentionMetadata],
+) -> type[AttentionMetadataBuilder[M]]:
+    """
+    Return a new subclass of `builder_cls` whose .build(...) method
+    first calls build_preprocess_fn(common_attn_metadata) on the metadata.
+    """
+    name: str = name_prefix + builder_cls.__name__  # type: ignore
+
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False):
+        return builder_cls.build(self, common_prefix_len,
+                                 build_preprocess_fn(common_attn_metadata),
+                                 fast_build)
+
+    Wrapped = type(
+        name,
+        (builder_cls, ),  # inherit from the original
+        {
+            "build": build,
+        })
+    return Wrapped  # type: ignore
+
+
+def subclass_attention_backend(
+        name_prefix: str, attention_backend_cls: type[AttentionBackend],
+        builder_cls: type[AttentionMetadataBuilder[M]]
+) -> type[AttentionBackend]:
+    """
+    Return a new subclass where `get_builder_cls` returns `builder_cls`.
+    """
+    name: str = name_prefix + attention_backend_cls.__name__  # type: ignore
+
+    return type(name, (attention_backend_cls, ),
+                {"get_builder_cls": lambda: builder_cls})
+
+
 def split_decodes_and_prefills(
     common_attn_metadata: CommonAttentionMetadata,
     decode_threshold: int = 1,
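Similarly, subclass_attention_backend relies on the fact that a function stored in a class dict and looked up on the class itself is returned unbound, so the zero-argument lambda can stand in for the zero-argument get_builder_cls staticmethod it shadows. A standalone sketch of that approach, again with hypothetical stand-in classes rather than vLLM ones:

    class DefaultBuilder:  # hypothetical default builder
        pass

    class WrappedBuilder(DefaultBuilder):  # hypothetical preprocessing builder
        pass

    class ToyBackend:  # hypothetical stand-in for an AttentionBackend
        @staticmethod
        def get_builder_cls() -> type:
            return DefaultBuilder

    def subclass_backend(name_prefix: str, backend_cls: type,
                         builder_cls: type) -> type:
        # Mirrors subclass_attention_backend: only get_builder_cls is
        # overridden. The lambda takes no arguments because it is looked up
        # on the class (NewBackend.get_builder_cls()), not on an instance,
        # matching the staticmethod it replaces.
        return type(name_prefix + backend_cls.__name__, (backend_cls, ),
                    {"get_builder_cls": lambda: builder_cls})

    LocalToyBackend = subclass_backend("Local", ToyBackend, WrappedBuilder)
    assert LocalToyBackend.get_builder_cls() is WrappedBuilder
    assert issubclass(LocalToyBackend, ToyBackend)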