[VLM] Reorganize profiling/processing-related code (#11812)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,6 +1,7 @@
+from abc import abstractmethod
 from functools import cached_property
 from typing import (Final, Iterable, List, Literal, Mapping, Optional,
-                    Protocol, Set, Tuple, TypedDict, Union)
+                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
 
 import torch
 import torch.nn as nn
@@ -16,13 +17,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
 from vllm.multimodal.parse import ImageSize
-from vllm.multimodal.profiling import BaseProfilingInfo
 from vllm.sequence import IntermediateTensors
 
 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingMixin,
-                    BaseLlavaProfilingInfo, LlavaLikeConfig,
+from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
+                    LlavaDummyInputsBuilder, LlavaLikeConfig,
                     LlavaMultiModalProjector, init_vision_tower_for_llava)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
@@ -65,23 +65,23 @@ class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
     image_grid_pinpoints: Final[list[list[int]]]
 
 
-class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
+class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
 
-    def _get_hf_config(self) -> LlavaNextLikeConfig:
+    def get_hf_config(self) -> LlavaNextLikeConfig:
         return self.ctx.get_hf_config(LlavaNextConfig)
 
-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaNextProcessor)
 
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
-    def _get_num_image_tokens(
+    def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
     ) -> int:
-        hf_config = self._get_hf_config()
-        vision_encoder_info = self._get_vision_encoder_info()
+        hf_config = self.get_hf_config()
+        vision_encoder_info = self.get_vision_encoder_info()
 
         base_feature_size = self._apply_feature_select_strategy(
             hf_config.vision_feature_select_strategy,
@@ -140,16 +140,13 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
 
         return (unpadded_features, newline_features)
 
-
-class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
-
-    def _get_image_size_with_most_features(self) -> ImageSize:
-        hf_config = self._get_hf_config()
+    def get_image_size_with_most_features(self) -> ImageSize:
+        hf_config = self.get_hf_config()
 
         largest_feature_size, largest_feature_pinpoint = 0, None
         for (height, width) in hf_config.image_grid_pinpoints:
-            feat_size = self._get_num_image_tokens(image_width=width,
-                                                   image_height=height)
+            feat_size = self.get_num_image_tokens(image_width=width,
+                                                  image_height=height)
             if feat_size > largest_feature_size:
                 largest_feature_size = feat_size
                 largest_feature_pinpoint = ImageSize(width=width,
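Note: with this hunk the profiling query moves onto the same object as the rest of the processing metadata, so the separate LlavaNextProfilingInfo class disappears and the helpers lose their leading underscore. A minimal usage sketch — `ctx` is assumed to be the InputProcessingContext that vLLM hands to the info class, and the concrete image size is made up for illustration:

    # Hedged sketch; only `ctx` and the sizes are assumptions,
    # the method names come from this diff.
    info = LlavaNextProcessingInfo(ctx)

    # Formerly the private _get_num_image_tokens() on the mixin.
    n = info.get_num_image_tokens(image_width=672, image_height=672)

    # Formerly LlavaNextProfilingInfo._get_image_size_with_most_features().
    width, height = info.get_image_size_with_most_features()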
@@ -161,11 +158,23 @@ class LlavaNextProfilingInfo(LlavaNextProcessingMixin, BaseLlavaProfilingInfo):
         return largest_feature_pinpoint
 
 
-class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
-                                   BaseLlavaMultiModalProcessor):
+_I = TypeVar("_I", bound=LlavaNextProcessingInfo)
 
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaNextProfilingInfo(self.ctx)
+
+class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
+
+    # Copied from BaseMultiModalProcessor
+    @abstractmethod
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        raise NotImplementedError
+
+
+class LlavaNextMultiModalProcessor(
+        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
 
     def _get_mm_fields_config(
         self,
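Note: the TypeVar makes the shared processor generic over its info type, so LLaVA-NeXT-style variants can subclass BaseLlavaNextMultiModalProcessor instead of re-mixing a processing mixin. A sketch of the pattern with a hypothetical variant — the MyVariant* names are invented, only the pattern comes from this diff:

    # Hypothetical variant reusing the generic base class.
    class MyVariantProcessingInfo(LlavaNextProcessingInfo):
        ...

    class MyVariantMultiModalProcessor(
            BaseLlavaNextMultiModalProcessor[MyVariantProcessingInfo]):

        def _get_mm_fields_config(self, hf_inputs, hf_processor_mm_kwargs):
            # Fill in the model-specific multimodal field mapping here.
            raise NotImplementedError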
@@ -179,7 +188,9 @@ class LlavaNextMultiModalProcessor(LlavaNextProcessingMixin,
         )
 
 
-@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
+                                        info=LlavaNextProcessingInfo,
+                                        dummy_inputs=LlavaDummyInputsBuilder)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                         SupportsPP):
 
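Note: registration now names all three pieces explicitly — the processor, its processing info, and a dummy-inputs builder used for memory profiling — instead of the processor owning a private _get_profiling_info() hook. As a sketch, a new model would wire up the same way; the MyModel* names are placeholders, the decorator signature matches the diff above:

    # Placeholder names; the keyword arguments mirror this commit.
    @MULTIMODAL_REGISTRY.register_processor(
        MyModelMultiModalProcessor,
        info=MyModelProcessingInfo,
        dummy_inputs=MyModelDummyInputsBuilder)
    class MyModelForConditionalGeneration(nn.Module, SupportsMultiModal):
        ...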