[VLM] Separate out profiling-related logic (#11746)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-06 16:02:21 +08:00
committed by GitHub
parent 2a622d704a
commit 996357e480
17 changed files with 1036 additions and 739 deletions

View File

@@ -1,4 +1,4 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
@@ -13,6 +13,7 @@ from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -25,9 +26,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
from vllm.multimodal.processing import (InputProcessingContext,
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingCache,
ProcessorInputs, PromptReplacement)
ProcessingMixin, PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
@@ -37,7 +39,7 @@ from .pixtral import (PixtralHFVisionModel,
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import BaseVisionLanguageMultiModalProcessor
from .vision import get_vision_encoder_info
class LlavaImagePixelInputs(TypedDict):
@@ -94,30 +96,42 @@ class LlavaMultiModalProjector(nn.Module):
class LlavaLikeConfig(Protocol):
vision_config: Final[PretrainedConfig]
image_token_index: Final[int]
vision_feature_select_strategy: Final[str]
vision_feature_layer: Final[Union[int, List[int]]]
vision_feature_layer: Final[Union[int, list[int]]]
class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
class LlavaLikeProcessor(Protocol):
image_token: Final[str]
class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
def _get_hf_config(self) -> LlavaLikeConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())
@abstractmethod
def _get_hf_config(self) -> LlavaLikeConfig:
def _get_hf_processor(self) -> LlavaLikeProcessor:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_mm_fields_config(
def _get_num_image_tokens(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def _apply_feature_select_strategy(
@@ -133,31 +147,38 @@ class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)
def _get_max_image_tokens(self) -> int:
hf_config = self._get_hf_config()
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_max_image_tokens(),
class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features()
return self._get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
def _get_dummy_image_size(self) -> ImageSize:
image_size = self._vision_encoder_info.get_image_size()
return ImageSize(image_size, image_size)
@abstractmethod
def _get_image_token(self) -> str:
raise NotImplementedError
def _get_dummy_processor_inputs(
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
image_token = self._get_image_token()
target_width, target_height = self._get_dummy_image_size()
processor = self._get_hf_processor()
image_token = processor.image_token
target_width, target_height = self._get_image_size_with_most_features()
mm_data = {
"image":
@@ -172,32 +193,32 @@ class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
)
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
class LlavaProcessingMixin(BaseLlavaProcessingMixin):
def _get_hf_config(self) -> LlavaConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_hf_processor(self) -> LlavaProcessor:
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
def _get_num_image_tokens(
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
pass
class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
BaseMultiModalProcessor):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_profiling_info(self) -> BaseProfilingInfo:
raise NotImplementedError
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
def _get_prompt_replacements(
self,
@@ -232,16 +253,37 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
]
class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
def _get_hf_config(self) -> LlavaConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaProfilingInfo(self.ctx)
def _get_hf_processor(self) -> PixtralProcessor:
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
def _get_hf_processor(self):
return self.ctx.get_hf_processor(PixtralProcessor)
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
pass
class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return PixtralHFProfilingInfo(self.ctx)
def _call_hf_processor(
self,
@@ -270,6 +312,16 @@ class PixtralHFMultiModalProcessor(BaseLlavaMultiModalProcessor):
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
@@ -316,7 +368,7 @@ def _build_llava_or_pixtral_hf_processor(
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseLlavaMultiModalProcessor:
) -> BaseMultiModalProcessor:
hf_config = ctx.get_hf_config(LlavaConfig)
if isinstance(hf_config.vision_config, PixtralVisionConfig):
@@ -663,16 +715,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self.ctx.get_hf_config(LlavaConfig)
hf_config = self._get_hf_config()
image_token_id = hf_config.image_token_index
# Assume that it doesn't depend on the image size