[Model] Support SDPA attention for Molmo vision backbone (#9410)

This commit is contained in:
Isotr0py
2024-10-16 19:52:01 +08:00
committed by GitHub
parent 59230ef32b
commit cf1d62a644
3 changed files with 61 additions and 78 deletions

View File

@@ -8,15 +8,22 @@ import torch.nn as nn
from torch.func import functional_call
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
from vllm.utils import is_cpu, is_pin_memory_available
logger = init_logger(__name__)
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
@@ -487,3 +494,29 @@ class LLMWrapper(nn.Module):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
llm = super().__getattr__(self.model_name)
return llm(*args, **kwargs)
def get_vit_attn_backend() -> _Backend:
selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
if selected_backend is None:
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
# For Volta and Turing GPUs, use xformers instead.
device_available = current_platform.has_device_capability(80)
if device_available:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
selected_backend = _Backend.FLASH_ATTN
else:
logger.warning(
"Current `vllm-flash-attn` has a bug inside vision module, "
"so we use xformers backend instead. You can run "
"`pip install flash-attn` to use flash-attention backend.")
selected_backend = _Backend.XFORMERS
elif is_cpu():
selected_backend = _Backend.TORCH_SDPA
else:
selected_backend = _Backend.XFORMERS
return selected_backend