[VLM] Separate out profiling-related logic (#11746)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
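This change splits the LLaVA multimodal processor into three cooperating pieces: a processing mixin that owns access to the HF config/processor and image-token counting, a profiling-info class that owns dummy-input generation for memory profiling, and the processor itself, which exposes its profiling companion through _get_profiling_info(). A rough sketch of the shape of that split, using simplified stand-ins for the vLLM base classes (class names mirror the diff below; the bodies are illustrative only, not the real implementations):

    from abc import ABC, abstractmethod

    # Simplified stand-ins for vllm.multimodal's ProcessingMixin,
    # BaseProfilingInfo and BaseMultiModalProcessor.
    class ProcessingMixin:
        """Shared access to the HF config/processor."""

        def _get_hf_config(self): ...

        def _get_hf_processor(self): ...


    class BaseProfilingInfo(ABC):
        """Profiling-only logic: dummy inputs and token limits."""

        @abstractmethod
        def get_dummy_processor_inputs(self, seq_len, mm_counts): ...


    class BaseMultiModalProcessor(ABC):
        """Runtime processing; profiling is reached via one hook."""

        @abstractmethod
        def _get_profiling_info(self) -> BaseProfilingInfo: ...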
@@ -1,4 +1,4 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
                     Protocol, Set, Tuple, TypedDict, Union)
@@ -13,6 +13,7 @@ from transformers.models.pixtral import PixtralProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
+from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -25,9 +26,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     NestedTensors)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize)
-from vllm.multimodal.processing import (InputProcessingContext,
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessingCache,
-                                        ProcessorInputs, PromptReplacement)
+                                        ProcessingMixin, PromptReplacement)
+from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 from .clip import CLIPVisionModel
@@ -37,7 +39,7 @@ from .pixtral import (PixtralHFVisionModel,
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import BaseVisionLanguageMultiModalProcessor
+from .vision import get_vision_encoder_info
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -94,30 +96,42 @@ class LlavaMultiModalProjector(nn.Module):
 
 class LlavaLikeConfig(Protocol):
     vision_config: Final[PretrainedConfig]
     image_token_index: Final[int]
     vision_feature_select_strategy: Final[str]
-    vision_feature_layer: Final[Union[int, List[int]]]
+    vision_feature_layer: Final[Union[int, list[int]]]
 
 
-class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
+class LlavaLikeProcessor(Protocol):
+    image_token: Final[str]
+
+
+class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
 
+    def _get_hf_config(self) -> LlavaLikeConfig:
+        return self.ctx.get_hf_config(LlavaConfig)
+
+    def _get_vision_encoder_info(self):
+        return get_vision_encoder_info(self._get_hf_config())
+
     @abstractmethod
-    def _get_hf_config(self) -> LlavaLikeConfig:
+    def _get_hf_processor(self) -> LlavaLikeProcessor:
         raise NotImplementedError
 
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {"image": self._get_max_image_tokens()}
-
-    def _get_mm_fields_config(
+    def _get_num_image_tokens(
         self,
-        hf_inputs: BatchFeature,
-        hf_processor_mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        hf_config = self._get_hf_config()
+        vision_encoder_info = self._get_vision_encoder_info()
+
+        return self._apply_feature_select_strategy(
+            hf_config.vision_feature_select_strategy,
+            vision_encoder_info.get_num_image_tokens(
+                image_width=image_width,
+                image_height=image_height,
+            ),
         )
 
     def _apply_feature_select_strategy(
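The new LlavaLikeProcessor protocol above relies on structural typing: any HF processor object that exposes an image_token attribute satisfies it, with no inheritance required, which is what lets both LlavaProcessor and PixtralProcessor be returned from _get_hf_processor. A minimal self-contained illustration (FakeProcessor is a hypothetical stand-in, not a real HF class):

    from typing import Final, Protocol

    class LlavaLikeProcessor(Protocol):
        image_token: Final[str]

    class FakeProcessor:
        # Matches the protocol purely by having the right attribute.
        image_token = "<image>"

    def image_token_of(processor: LlavaLikeProcessor) -> str:
        return processor.image_token

    assert image_token_of(FakeProcessor()) == "<image>"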
@@ -133,31 +147,38 @@ class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
         msg = f"Unexpected feature select strategy: {strategy!r}"
         raise NotImplementedError(msg)
 
-    def _get_max_image_tokens(self) -> int:
-        hf_config = self._get_hf_config()
-
-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
-            self._vision_encoder_info.get_max_image_tokens(),
+
+class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {"image": self._get_max_image_tokens()}
+
+    def _get_image_size_with_most_features(self) -> ImageSize:
+        vision_encoder_info = self._get_vision_encoder_info()
+        width = height = vision_encoder_info.get_image_size()
+        return ImageSize(width=width, height=height)
+
+    def _get_max_image_tokens(self) -> int:
+        target_width, target_height = self._get_image_size_with_most_features()
+
+        return self._get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
         )
 
-    def _get_dummy_image_size(self) -> ImageSize:
-        image_size = self._vision_encoder_info.get_image_size()
-        return ImageSize(image_size, image_size)
-
-    @abstractmethod
-    def _get_image_token(self) -> str:
-        raise NotImplementedError
-
-    def _get_dummy_processor_inputs(
+    def get_dummy_processor_inputs(
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
         num_images = mm_counts.get("image", 0)
 
-        image_token = self._get_image_token()
-        target_width, target_height = self._get_dummy_image_size()
+        processor = self._get_hf_processor()
+        image_token = processor.image_token
+        target_width, target_height = self._get_image_size_with_most_features()
 
         mm_data = {
             "image":
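get_dummy_processor_inputs exists so the profiler can construct a worst-case batch: it asks for the image size that yields the most image tokens and fabricates that many images. A rough sketch of the image side of that (the real method also assembles the prompt text; 336 is just an example size, not taken from this diff):

    from PIL import Image

    def make_dummy_images(width: int, height: int, num_images: int):
        # Zero-filled RGB images are enough for memory profiling:
        # only the resulting tensor shapes matter, not pixel values.
        return [Image.new("RGB", (width, height), color=0)
                for _ in range(num_images)]

    mm_data = {"image": make_dummy_images(336, 336, num_images=1)}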
@@ -172,32 +193,32 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
         )
 
 
-class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
+class LlavaProcessingMixin(BaseLlavaProcessingMixin):
 
-    def _get_hf_config(self) -> LlavaConfig:
-        return self.ctx.get_hf_config(LlavaConfig)
-
-    def _get_hf_processor(self) -> LlavaProcessor:
+    def _get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaProcessor)
 
-    def _get_image_token(self) -> str:
-        return self._get_hf_processor().image_token
-
-    def _get_num_image_tokens(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-    ) -> int:
-        hf_config = self._get_hf_config()
-
-        return self._apply_feature_select_strategy(
-            hf_config.vision_feature_select_strategy,
-            self._vision_encoder_info.get_num_image_tokens(
-                image_width=image_width,
-                image_height=image_height,
-            ),
-        )
+
+class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
+    pass
+
+
+class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
+                                   BaseMultiModalProcessor):
+
+    # Copied from BaseMultiModalProcessor
+    @abstractmethod
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        raise NotImplementedError
+
+    # Copied from BaseMultiModalProcessor
+    @abstractmethod
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        raise NotImplementedError
 
     def _get_prompt_replacements(
         self,
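The one-line LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo) class works because Python's MRO lets the concrete mixin on the left satisfy the abstract hook that the profiling base relies on. A generic toy version of the pattern (all names here are hypothetical):

    from abc import ABC, abstractmethod

    class Base(ABC):
        @abstractmethod
        def _get_hf_processor(self): ...

    class ProfilingBase(Base):
        def dummy_prompt(self) -> str:
            # Relies on whichever concrete class supplies the processor.
            return self._get_hf_processor().image_token

    class Mixin(Base):
        def _get_hf_processor(self):
            class _Proc:  # stand-in for ctx.get_hf_processor(...)
                image_token = "<image>"
            return _Proc()

    class ProfilingInfo(Mixin, ProfilingBase):
        pass  # the mixin fills in ProfilingBase's abstract hook

    assert ProfilingInfo().dummy_prompt() == "<image>"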
@@ -232,16 +253,37 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
         ]
 
 
-class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
+class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
 
-    def _get_hf_config(self) -> LlavaConfig:
-        return self.ctx.get_hf_config(LlavaConfig)
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return LlavaProfilingInfo(self.ctx)
 
-    def _get_hf_processor(self) -> PixtralProcessor:
-        return self.ctx.get_hf_processor(PixtralProcessor)
-
-    def _get_image_token(self) -> str:
-        return self._get_hf_processor().image_token
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+
+class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
+
+    def _get_hf_processor(self):
+        return self.ctx.get_hf_processor(PixtralProcessor)
+
+
+class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
+    pass
+
+
+class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
+                                   BaseMultiModalProcessor):
+
+    def _get_profiling_info(self) -> BaseProfilingInfo:
+        return PixtralHFProfilingInfo(self.ctx)
 
     def _call_hf_processor(
         self,
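With each processor owning a _get_profiling_info() hook, profiling code no longer needs to reach into processor internals. A hedged sketch of the intended call path, using only the methods visible in this diff (the actual driver lives elsewhere, in vllm.multimodal.profiling):

    def profile_worst_case(processor, seq_len: int = 4096):
        # Ask the processor for its profiling companion, then query
        # limits, per-item token counts, and dummy inputs from it.
        info = processor._get_profiling_info()
        limits = info.get_supported_mm_limits()        # e.g. {"image": None}
        max_tokens = info.get_mm_max_tokens_per_item(seq_len)
        dummy = info.get_dummy_processor_inputs(seq_len, {"image": 1})
        return limits, max_tokens, dummy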
@@ -270,6 +312,16 @@ class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
 
         return processed_outputs
 
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
     def _get_prompt_replacements(
         self,
         mm_items: MultiModalDataItems,
@@ -316,7 +368,7 @@ def _build_llava_or_pixtral_hf_processor(
     *,
     cache: Optional[ProcessingCache] = None,
     enable_sanity_checks: bool = True,
-) -> BaseLlavaMultiModalProcessor:
+) -> BaseMultiModalProcessor:
     hf_config = ctx.get_hf_config(LlavaConfig)
 
     if isinstance(hf_config.vision_config, PixtralVisionConfig):
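The builder's return type widens to BaseMultiModalProcessor because the LLaVA and Pixtral processors no longer share a LLaVA-specific base class; the factory simply dispatches on the type of the nested vision config. The same dispatch shape in miniature (toy config classes, not the transformers ones):

    class CLIPVisionConfig: ...      # toy stand-in
    class PixtralVisionConfig: ...   # toy stand-in

    def pick_processor_name(vision_config) -> str:
        if isinstance(vision_config, PixtralVisionConfig):
            return "PixtralHFMultiModalProcessor"
        return "LlavaMultiModalProcessor"

    assert pick_processor_name(PixtralVisionConfig()).startswith("Pixtral")
    assert pick_processor_name(CLIPVisionConfig()) == "LlavaMultiModalProcessor"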
@@ -663,16 +715,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
 class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
-    def _get_hf_processor(self):
-        return self.ctx.get_hf_processor(LlavaProcessor)
-
     def apply(
         self,
         prompt_text: str,
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        hf_config = self.ctx.get_hf_config(LlavaConfig)
+        hf_config = self._get_hf_config()
         image_token_id = hf_config.image_token_index
 
         # Assume that it doesn't depend on the image size