[VLM] Separate out profiling-related logic (#11746)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,12 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar
|
||||
from typing import Final, Generic, Protocol, TypeVar
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
InputProcessingContext,
|
||||
ProcessingCache)
|
||||
|
||||
_C = TypeVar("_C", bound=PretrainedConfig)
|
||||
|
||||
|
||||
@@ -43,12 +39,18 @@ class VisionEncoderInfo(ABC, Generic[_C]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
|
||||
class VisionLanguageConfig(Protocol):
|
||||
vision_config: Final[PretrainedConfig]
|
||||
|
||||
|
||||
def get_vision_encoder_info(
|
||||
hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
|
||||
# Avoid circular imports
|
||||
from .clip import CLIPEncoderInfo, CLIPVisionConfig
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
|
||||
from .siglip import SiglipEncoderInfo, SiglipVisionConfig
|
||||
|
||||
vision_config = hf_config.vision_config
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return CLIPEncoderInfo(vision_config)
|
||||
if isinstance(vision_config, PixtralVisionConfig):
|
||||
@@ -58,26 +60,3 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo:
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
class VisionLanguageConfig(Protocol):
|
||||
vision_config: Final[PretrainedConfig]
|
||||
|
||||
|
||||
class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
def __init__(self,
|
||||
ctx: InputProcessingContext,
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True) -> None:
|
||||
super().__init__(ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks)
|
||||
|
||||
vision_config = self._get_hf_config().vision_config
|
||||
self._vision_encoder_info = vision_encoder_info(vision_config)
|
||||
|
||||
@abstractmethod
|
||||
def _get_hf_config(self) -> VisionLanguageConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user