[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-08 18:59:58 +08:00
committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
@@ -16,13 +17,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.profiling import BaseProfilingInfo
from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
BaseLlavaProfilingInfo, LlavaLikeConfig,
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
@@ -65,23 +65,23 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
image_grid_pinpoints: Final[list[list[int]]]
class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
def _get_hf_config(self) -> LlavaNextLikeConfig:
def get_hf_config(self) -> LlavaNextLikeConfig:
return self.ctx.get_hf_config(LlavaNextConfig)
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextProcessor)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
def _get_num_image_tokens(
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
vision_encoder_info = self._get_vision_encoder_info()
hf_config = self.get_hf_config()
vision_encoder_info = self.get_vision_encoder_info()
base_feature_size = self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy,
@@ -140,16 +140,13 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
return (unpadded_features, newline_features)
class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
def _get_image_size_with_most_features(self) -> ImageSize:
hf_config = self._get_hf_config()
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
largest_feature_size, largest_feature_pinpoint = 0, None
for (height, width) in hf_config.image_grid_pinpoints:
feat_size = self._get_num_image_tokens(image_width=width,
image_height=height)
feat_size = self.get_num_image_tokens(image_width=width,
image_height=height)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
@@ -161,11 +158,23 @@ class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
return largest_feature_pinpoint
class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
BaseLlavaMultiModalProcessor):
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextProfilingInfo(self.ctx)
class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
# Copied from BaseMultiModalProcessor
@abstractmethod
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
class LlavaNextMultiModalProcessor(
BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
def _get_mm_fields_config(
self,
@@ -179,7 +188,9 @@ class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
info=LlavaNextProcessingInfo,
dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):