[VLM] Separate out profiling-related logic (#11746)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-06 16:02:21 +08:00
committed by GitHub
parent 2a622d704a
commit 996357e480
17 changed files with 1036 additions and 739 deletions

View File

@@ -1,12 +1,8 @@
from abc import ABC, abstractmethod
from typing import Final, Generic, Optional, Protocol, TypeVar
from typing import Final, Generic, Protocol, TypeVar
from transformers import PretrainedConfig
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ProcessingCache)
_C = TypeVar("_C", bound=PretrainedConfig)
@@ -43,12 +39,18 @@ class VisionEncoderInfo(ABC, Generic[_C]):
raise NotImplementedError
def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
def get_vision_encoder_info(
hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
# Avoid circular imports
from .clip import CLIPEncoderInfo, CLIPVisionConfig
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
from .siglip import SiglipEncoderInfo, SiglipVisionConfig
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return CLIPEncoderInfo(vision_config)
if isinstance(vision_config, PixtralVisionConfig):
@@ -58,26 +60,3 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class VisionLanguageConfig(Protocol):
vision_config: Final[PretrainedConfig]
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod
def _get_hf_config(self) -> VisionLanguageConfig:
raise NotImplementedError