[Attention] Unify mamba and attention backend selection (#23171)
Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
@@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@@ -55,7 +56,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                         get_dtype_size, is_pin_memory_available, round_up,
                         supports_dynamo)
-from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     make_kv_sharing_fast_prefill_attention_metadata,
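The import changes carry the core idea of the commit: MambaBase now sits behind the same AttentionLayerBase interface that attention layers use, so the model runner no longer needs the string-based get_mamba_attn_backend selector. A rough, illustrative sketch of that shared-interface shape follows; the class names and bodies below are stand-ins, not vLLM's actual definitions:

from abc import ABC, abstractmethod


class AttentionBackendStub:
    """Stand-in for an attention backend class (not vLLM's AttentionBackend)."""


class LayerBaseSketch(ABC):
    """Stand-in for AttentionLayerBase: any layer that owns KV-cache-like state."""

    @abstractmethod
    def get_attn_backend(self) -> type[AttentionBackendStub]:
        """Return the attention backend class this layer wants to use."""


class MambaLayerSketch(LayerBaseSketch):
    """Stand-in for a mamba layer after the unification."""

    def get_attn_backend(self) -> type[AttentionBackendStub]:
        # A real mamba layer would return its mamba-specific backend class here.
        return AttentionBackendStub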
@@ -2747,11 +2747,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"
-        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

         def get_attn_backends_for_layers(
                 layer_names: list[str]
         ) -> dict[type[AttentionBackend], list[str]]:
+            layers = get_layers_from_vllm_config(self.vllm_config,
+                                                 AttentionLayerBase,
+                                                 layer_names)
             attn_backends = {}
             attn_backend_layers = defaultdict(list)
             # Dedupe based on full class name; this is a bit safer than using
@@ -2760,7 +2762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # they are cached correctly, there will be different objects per
             # layer.
             for layer_name in layer_names:
-                attn_backend = attn_layers[layer_name].get_attn_backend()
+                attn_backend = layers[layer_name].get_attn_backend()
                 key = attn_backend.full_cls_name()
                 attn_backends[key] = attn_backend
                 attn_backend_layers[key].append(layer_name)
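With the layer lookup moved inside the helper, get_attn_backends_for_layers now resolves the named layers through AttentionLayerBase and asks each layer for its own backend, deduplicating on the backend's full class name. A self-contained sketch of that grouping pattern, using stand-in types (the real code goes through get_layers_from_vllm_config, AttentionLayerBase and AttentionBackend.full_cls_name()):

from collections import defaultdict
from typing import Protocol


class HasAttnBackend(Protocol):
    def get_attn_backend(self) -> type: ...


def group_layers_by_backend(layers: dict[str, HasAttnBackend],
                            layer_names: list[str]) -> dict[type, list[str]]:
    backends: dict[str, type] = {}        # full class name -> backend class
    backend_layers = defaultdict(list)    # full class name -> layer names
    for name in layer_names:
        backend = layers[name].get_attn_backend()
        # Dedupe on the fully qualified class name, mirroring full_cls_name().
        key = f"{backend.__module__}.{backend.__qualname__}"
        backends[key] = backend
        backend_layers[key].append(name)
    return {backends[key]: names for key, names in backend_layers.items()}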
@@ -2789,20 +2791,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            if isinstance(kv_cache_spec, AttentionSpec):
-                attn_backends = get_attn_backends_for_layers(
-                    kv_cache_group_spec.layer_names)
-            # TODO(lucas): move `get_mamba_attn_backend` into the mamba
-            # layers like above
-            elif isinstance(kv_cache_spec, MambaSpec):
-                attn_backends = {
-                    get_mamba_attn_backend(kv_cache_spec.mamba_type):
-                    kv_cache_group_spec.layer_names
-                }
-            else:
-                raise ValueError(
-                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")
-
+            attn_backends = get_attn_backends_for_layers(
+                kv_cache_group_spec.layer_names)
             self.attn_groups.append(
                 create_attn_groups(attn_backends, kv_cache_spec))
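Because every layer now answers get_attn_backend() itself, the per-group isinstance dispatch over AttentionSpec and MambaSpec disappears and attention and mamba KV cache groups flow through the same path. A hedged, self-contained usage sketch with hypothetical layer and backend names:

from collections import defaultdict


class FlashBackendStub: ...


class MambaBackendStub: ...


class AttnLayerStub:
    def get_attn_backend(self) -> type:
        return FlashBackendStub


class MambaLayerStub:
    def get_attn_backend(self) -> type:
        return MambaBackendStub


layers = {
    "model.layers.0.attn": AttnLayerStub(),
    "model.layers.1.attn": AttnLayerStub(),
    "model.layers.2.mixer": MambaLayerStub(),
}
groups: dict[type, list[str]] = defaultdict(list)
for name, layer in layers.items():
    # One code path: each layer, attention or mamba, reports its own backend.
    groups[layer.get_attn_backend()].append(name)
print(dict(groups))
# Groups the two attention layers under FlashBackendStub and the mamba layer
# under MambaBackendStub, with no isinstance() dispatch in the caller.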