[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-08 18:59:58 +08:00
committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions

View File

@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
@@ -25,11 +25,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize)
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingCache,
ProcessingMixin, PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
@@ -105,34 +105,23 @@ class LlavaLikeProcessor(Protocol):
image_token: Final[str]
class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
class BaseLlavaProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self) -> LlavaLikeConfig:
def get_hf_config(self) -> LlavaLikeConfig:
return self.ctx.get_hf_config(LlavaConfig)
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
@abstractmethod
def _get_hf_processor(self) -> LlavaLikeProcessor:
def get_hf_processor(self) -> LlavaLikeProcessor:
raise NotImplementedError
def _get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """Return the per-modality limit mapping.

    Only the ``"image"`` modality is listed, and its limit is ``None``.
    """
    limits = {"image": None}
    return limits
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
    """Map ``"image"`` to the value of ``self.get_max_image_tokens()``.

    ``seq_len`` is accepted but not consulted by this implementation.
    """
    max_image_tokens = self.get_max_image_tokens()
    return {"image": max_image_tokens}
def _apply_feature_select_strategy(
self,
@@ -147,28 +136,42 @@ class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
msg = f"Unexpected feature select strategy: {strategy!r}"
raise NotImplementedError(msg)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
return self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
vision_encoder_info.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
),
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
def get_image_size_with_most_features(self) -> ImageSize:
    """Return a square ``ImageSize`` built from the vision encoder's image size.

    The side length is the value reported by ``get_image_size()`` on the
    object returned from ``self.get_vision_encoder_info()``; it is used for
    both ``width`` and ``height``.
    """
    side = self.get_vision_encoder_info().get_image_size()
    return ImageSize(width=side, height=side)
def _get_max_image_tokens(self) -> int:
target_width, target_height = self._get_image_size_with_most_features()
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self._get_num_image_tokens(
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@@ -176,9 +179,10 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
mm_data = {
"image":
@@ -193,23 +197,13 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
)
class LlavaProcessingMixin(BaseLlavaProcessingMixin):
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaProcessor)
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
pass
class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
BaseMultiModalProcessor):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_profiling_info(self) -> BaseProfilingInfo:
raise NotImplementedError
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
@@ -226,7 +220,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
def get_replacement(item_idx: int):
@@ -237,7 +231,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
num_image_tokens = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
)
@@ -253,10 +247,8 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
]
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaProfilingInfo(self.ctx)
class LlavaMultiModalProcessor(
BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
def _get_mm_fields_config(
self,
@@ -269,21 +261,14 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
)
class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(PixtralProcessor)
class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
pass
class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return PixtralHFProfilingInfo(self.ctx)
class PixtralHFMultiModalProcessor(
BaseMultiModalProcessor[PixtralHFProcessingInfo]):
def _call_hf_processor(
self,
@@ -328,10 +313,10 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
@@ -363,26 +348,40 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
]
def _build_llava_or_pixtral_hf_info(
    ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    """Choose the processing-info class matching the model's vision config.

    The ``LlavaConfig`` is read from ``ctx``. If its ``vision_config`` is a
    ``PixtralVisionConfig``, a ``PixtralHFProcessingInfo`` is returned;
    otherwise a ``LlavaProcessingInfo``. Either one is constructed from
    ``ctx`` alone.
    """
    vision_config = ctx.get_hf_config(LlavaConfig).vision_config

    # Anything that is not a Pixtral vision tower takes the plain LLaVA path.
    if not isinstance(vision_config, PixtralVisionConfig):
        return LlavaProcessingInfo(ctx)

    return PixtralHFProcessingInfo(ctx)
def _build_llava_or_pixtral_hf_processor(
ctx: InputProcessingContext,
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
hf_config = ctx.get_hf_config(LlavaConfig)
if isinstance(hf_config.vision_config, PixtralVisionConfig):
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(
ctx,
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
return LlavaMultiModalProcessor(
ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
if isinstance(info, LlavaProcessingInfo):
return LlavaMultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
raise NotImplementedError(type(info))
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
@@ -460,7 +459,9 @@ def init_vision_tower_for_llava(
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
info=_build_llava_or_pixtral_hf_info,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitsAndBytes-specific attributes
bitsandbytes_stacked_params_mapping = {
@@ -727,11 +728,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
# Assume that it doesn't depend on the image size
num_image_tokens = self._get_num_image_tokens(
num_image_tokens = self.info.get_num_image_tokens(
image_width=-1,
image_height=-1,
)
@@ -796,6 +797,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
info=LlavaProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass