[VLM] Reorganize profiling/processing-related code (#11812)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from abc import abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
||||
Protocol, Set, Tuple, TypedDict, Union)
|
||||
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -25,11 +25,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize)
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessingCache,
|
||||
ProcessingMixin, PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
@@ -105,34 +105,23 @@ class LlavaLikeProcessor(Protocol):
|
||||
image_token: Final[str]
|
||||
|
||||
|
||||
class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
|
||||
class BaseLlavaProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def _get_hf_config(self) -> LlavaLikeConfig:
|
||||
def get_hf_config(self) -> LlavaLikeConfig:
|
||||
return self.ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
def _get_vision_encoder_info(self):
|
||||
return get_vision_encoder_info(self._get_hf_config())
|
||||
def get_vision_encoder_info(self):
|
||||
return get_vision_encoder_info(self.get_hf_config())
|
||||
|
||||
@abstractmethod
|
||||
def _get_hf_processor(self) -> LlavaLikeProcessor:
|
||||
def get_hf_processor(self) -> LlavaLikeProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self._get_hf_config()
|
||||
vision_encoder_info = self._get_vision_encoder_info()
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
return self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
vision_encoder_info.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
)
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
|
||||
def _apply_feature_select_strategy(
|
||||
self,
|
||||
@@ -147,28 +136,42 @@ class BaseLlavaProcessingMixin(ProcessingMixin, ABC):
|
||||
msg = f"Unexpected feature select strategy: {strategy!r}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
vision_encoder_info = self.get_vision_encoder_info()
|
||||
|
||||
class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
|
||||
return self._apply_feature_select_strategy(
|
||||
hf_config.vision_feature_select_strategy,
|
||||
vision_encoder_info.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
return {"image": self._get_max_image_tokens()}
|
||||
|
||||
def _get_image_size_with_most_features(self) -> ImageSize:
|
||||
vision_encoder_info = self._get_vision_encoder_info()
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
vision_encoder_info = self.get_vision_encoder_info()
|
||||
width = height = vision_encoder_info.get_image_size()
|
||||
return ImageSize(width=width, height=height)
|
||||
|
||||
def _get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self._get_image_size_with_most_features()
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self._get_num_image_tokens(
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
)
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseLlavaProcessingInfo)
|
||||
|
||||
|
||||
class LlavaDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
@@ -176,9 +179,10 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
processor = self._get_hf_processor()
|
||||
processor = self.info.get_hf_processor()
|
||||
image_token = processor.image_token
|
||||
target_width, target_height = self._get_image_size_with_most_features()
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
@@ -193,23 +197,13 @@ class BaseLlavaProfilingInfo(BaseLlavaProcessingMixin, BaseProfilingInfo):
|
||||
)
|
||||
|
||||
|
||||
class LlavaProcessingMixin(BaseLlavaProcessingMixin):
|
||||
class LlavaProcessingInfo(BaseLlavaProcessingInfo):
|
||||
|
||||
def _get_hf_processor(self):
|
||||
def get_hf_processor(self):
|
||||
return self.ctx.get_hf_processor(LlavaProcessor)
|
||||
|
||||
|
||||
class LlavaProfilingInfo(LlavaProcessingMixin, BaseLlavaProfilingInfo):
|
||||
pass
|
||||
|
||||
|
||||
class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
||||
BaseMultiModalProcessor):
|
||||
|
||||
# Copied from BaseMultiModalProcessor
|
||||
@abstractmethod
|
||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
||||
raise NotImplementedError
|
||||
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
# Copied from BaseMultiModalProcessor
|
||||
@abstractmethod
|
||||
@@ -226,7 +220,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self._get_hf_config()
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
@@ -237,7 +231,7 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
||||
num_image_tokens = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
num_image_tokens = self._get_num_image_tokens(
|
||||
num_image_tokens = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
@@ -253,10 +247,8 @@ class BaseLlavaMultiModalProcessor(LlavaProcessingMixin,
|
||||
]
|
||||
|
||||
|
||||
class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
||||
|
||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
||||
return LlavaProfilingInfo(self.ctx)
|
||||
class LlavaMultiModalProcessor(
|
||||
BaseLlavaMultiModalProcessor[LlavaProcessingInfo]):
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
@@ -269,21 +261,14 @@ class LlavaMultiModalProcessor(BaseLlavaMultiModalProcessor):
|
||||
)
|
||||
|
||||
|
||||
class PixtralHFProcessingMixin(BaseLlavaProcessingMixin):
|
||||
class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):
|
||||
|
||||
def _get_hf_processor(self):
|
||||
def get_hf_processor(self):
|
||||
return self.ctx.get_hf_processor(PixtralProcessor)
|
||||
|
||||
|
||||
class PixtralHFProfilingInfo(PixtralHFProcessingMixin, BaseLlavaProfilingInfo):
|
||||
pass
|
||||
|
||||
|
||||
class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
|
||||
BaseMultiModalProcessor):
|
||||
|
||||
def _get_profiling_info(self) -> BaseProfilingInfo:
|
||||
return PixtralHFProfilingInfo(self.ctx)
|
||||
class PixtralHFMultiModalProcessor(
|
||||
BaseMultiModalProcessor[PixtralHFProcessingInfo]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
@@ -328,10 +313,10 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self._get_hf_config()
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
processor = self._get_hf_processor()
|
||||
processor = self.info.get_hf_processor()
|
||||
image_token = processor.image_token
|
||||
image_break_token = processor.image_break_token
|
||||
image_end_token = processor.image_end_token
|
||||
@@ -363,26 +348,40 @@ class PixtralHFMultiModalProcessor(PixtralHFProcessingMixin,
|
||||
]
|
||||
|
||||
|
||||
def _build_llava_or_pixtral_hf_info(
|
||||
ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
||||
return PixtralHFProcessingInfo(ctx)
|
||||
|
||||
return LlavaProcessingInfo(ctx)
|
||||
|
||||
|
||||
def _build_llava_or_pixtral_hf_processor(
|
||||
ctx: InputProcessingContext,
|
||||
info: _I,
|
||||
dummy_inputs: BaseDummyInputsBuilder[_I],
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True,
|
||||
) -> BaseMultiModalProcessor:
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
|
||||
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
||||
if isinstance(info, PixtralHFProcessingInfo):
|
||||
return PixtralHFMultiModalProcessor(
|
||||
ctx,
|
||||
info,
|
||||
dummy_inputs, # type: ignore
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
|
||||
return LlavaMultiModalProcessor(
|
||||
ctx,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
if isinstance(info, LlavaProcessingInfo):
|
||||
return LlavaMultiModalProcessor(
|
||||
info,
|
||||
dummy_inputs, # type: ignore
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
|
||||
raise NotImplementedError(type(info))
|
||||
|
||||
|
||||
def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
|
||||
@@ -460,7 +459,9 @@ def init_vision_tower_for_llava(
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor)
|
||||
@MULTIMODAL_REGISTRY.register_processor(_build_llava_or_pixtral_hf_processor,
|
||||
info=_build_llava_or_pixtral_hf_info,
|
||||
dummy_inputs=LlavaDummyInputsBuilder)
|
||||
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
# BitandBytes specific attributes
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
@@ -727,11 +728,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
hf_config = self._get_hf_config()
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
# Assume that it doesn't depend on the image size
|
||||
num_image_tokens = self._get_num_image_tokens(
|
||||
num_image_tokens = self.info.get_num_image_tokens(
|
||||
image_width=-1,
|
||||
image_height=-1,
|
||||
)
|
||||
@@ -796,6 +797,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
# To use this model, please use
|
||||
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
||||
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
|
||||
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
|
||||
info=LlavaProcessingInfo,
|
||||
dummy_inputs=LlavaDummyInputsBuilder)
|
||||
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user