[VLM] Reorganize profiling/processing-related code (#11812)
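
This moves LLaVA-OneVision onto the reorganized multimodal processing
interfaces: the former LlavaOnevisionProcessingMixin and
LlavaOnevisionProfilingInfo are consolidated into a single
LlavaOnevisionProcessingInfo, dummy-input construction is split out into
LlavaOnevisionDummyInputsBuilder, the underscore-prefixed helpers become
public methods reached through self.info, and all three classes are handed
to MULTIMODAL_REGISTRY.register_processor().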
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -17,19 +17,20 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
-                                   VideoEmbeddingItems, VideoProcessorItems)
-from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
+                                   VideoProcessorItems)
+from vllm.multimodal.processing import PromptReplacement
+from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
 
 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
-from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
-                         LlavaNextProcessingMixin)
+from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
+from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
+                         LlavaNextProcessingInfo)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -89,14 +90,23 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
     video_token_index: Final[int]
 
 
-class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
+class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
 
-    def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
+    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
         return self.ctx.get_hf_config(LlavaOnevisionConfig)
 
-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
 
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None, "video": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {
+            "image": self.get_max_image_tokens(),
+            "video": self.get_max_video_tokens(seq_len),
+        }
+
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
     # with additional logic afterwards taken from LlavaOnevisionProcessor
     def _get_num_unpadded_features(
@@ -141,16 +151,16 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
         image_width: int,
         image_height: int,
     ) -> int:
-        hf_config = self._get_hf_config()
+        hf_config = self.get_hf_config()
         spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
 
-        vision_encoder_info = self._get_vision_encoder_info()
+        vision_encoder_info = self.get_vision_encoder_info()
         patch_grid_length = vision_encoder_info.get_patch_grid_length()
         pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
 
         return pooled_grid_length * pooled_grid_length
 
-    def _get_num_video_tokens(
+    def get_num_video_tokens(
         self,
         *,
         image_width: int,
@@ -164,43 +174,14 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
 
         return num_frame_tokens * num_frames + 1  # Newline token
 
-
-class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
-                                  BaseLlavaProfilingInfo):
-
-    def _get_image_size_with_most_features(self) -> ImageSize:
-        hf_config = self._get_hf_config()
-        largest_feature_size, largest_feature_pinpoint = 0, None
-        for (height, width) in hf_config.image_grid_pinpoints:
-            feat_size = self._get_num_image_tokens(image_width=width,
-                                                   image_height=height)
-            if feat_size > largest_feature_size:
-                largest_feature_size = feat_size
-                largest_feature_pinpoint = ImageSize(width=width,
-                                                     height=height)
-
-        if largest_feature_size == 0 or largest_feature_pinpoint is None:
-            raise ValueError("Cannot have a largest feature size of 0!")
-
-        return largest_feature_pinpoint
-
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None, "video": None}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {
-            "image": self._get_max_image_tokens(),
-            "video": self._get_max_video_tokens(seq_len),
-        }
-
     def _get_max_video_frames(self, max_tokens: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = self.get_image_size_with_most_features()
 
         num_frames = 0
 
         while True:
             next_num_frames = num_frames + 1
-            next_max_tokens = self._get_num_video_tokens(
+            next_max_tokens = self.get_num_video_tokens(
                 image_width=target_width,
                 image_height=target_height,
                 num_frames=next_num_frames,
@@ -213,12 +194,12 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
 
         return num_frames
 
-    def _get_dummy_num_frames(self, seq_len: int) -> int:
+    def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
         max_images = mm_config.limit_per_prompt.get("image", 1)
         max_videos = mm_config.limit_per_prompt.get("video", 1)
 
-        max_image_tokens = self._get_max_image_tokens() * max_images
+        max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
                                                       max_image_tokens)
         max_frames_per_video = min(max_total_frames // max(max_videos, 1),
@@ -226,15 +207,19 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
 
         return max(max_frames_per_video, 1)
 
-    def _get_max_video_tokens(self, seq_len: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+    def get_max_video_tokens(self, seq_len: int) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
 
-        return self._get_num_video_tokens(
+        return self.get_num_video_tokens(
             image_width=target_width,
             image_height=target_height,
-            num_frames=self._get_dummy_num_frames(seq_len),
+            num_frames=self.get_num_frames_with_most_features(seq_len),
         )
 
+
+class LlavaOnevisionDummyInputsBuilder(
+        LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
+
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
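
An aside on the budget arithmetic in _get_max_video_frames and
get_num_frames_with_most_features above: the loop's break condition and the
second argument of the final min(...) clamp both fall outside the hunks
shown, so the plain-Python restatement below is a sketch under those
assumptions, with made-up token costs.

def max_video_frames(max_tokens: int, tokens_per_frame: int) -> int:
    """Largest frame count whose token cost still fits in max_tokens
    (mirrors the while-loop in _get_max_video_frames)."""
    num_frames = 0
    while True:
        next_num_frames = num_frames + 1
        # get_num_video_tokens(): per-frame tokens plus one newline token.
        next_max_tokens = tokens_per_frame * next_num_frames + 1
        if next_max_tokens > max_tokens:  # assumed break condition
            break
        num_frames = next_num_frames
    return num_frames


def frames_with_most_features(seq_len: int, max_image_tokens: int,
                              max_videos: int, tokens_per_frame: int) -> int:
    """Images claim their worst-case budget first; videos split the rest
    (mirrors get_num_frames_with_most_features, minus the final clamp)."""
    max_total_frames = max_video_frames(seq_len - max_image_tokens,
                                        tokens_per_frame)
    max_frames_per_video = max_total_frames // max(max_videos, 1)
    return max(max_frames_per_video, 1)


# With a 32768-token budget, two images costing 7300 tokens each, two videos,
# and 196 tokens per frame: 32768 - 14600 = 18168 tokens fit 92 frames in
# total, so each dummy video gets 46 frames.
print(frames_with_most_features(32768, 2 * 7300, 2, 196))  # 46
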
@@ -243,10 +228,14 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         num_images = mm_counts.get("image", 0)
         num_videos = mm_counts.get("video", 0)
 
-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         image_token = processor.image_token
         video_token = processor.video_token
-        target_width, target_height = self._get_image_size_with_most_features()
+
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        target_num_frames = \
+            self.info.get_num_frames_with_most_features(seq_len)
 
         mm_data = {
             "image":
@@ -257,7 +246,7 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
             self._get_dummy_videos(
                 width=target_width,
                 height=target_height,
-                num_frames=self._get_dummy_num_frames(seq_len),
+                num_frames=target_num_frames,
                 num_videos=num_videos,
             )
         }
@@ -268,11 +257,8 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         )
 
 
-class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
-                                        LlavaNextMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaOnevisionProfilingInfo(self.ctx)
+class LlavaOnevisionMultiModalProcessor(
+        BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):
 
     def _get_mm_fields_config(
         self,
@@ -303,7 +289,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
             mm_kwargs=mm_kwargs,
         )
 
-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         video_token = processor.video_token
 
         # LLaVA-OneVision processor doesn't support multiple videos
@@ -345,7 +331,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
             out_mm_kwargs=out_mm_kwargs,
         )
 
-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         video_token_id = hf_config.video_token_index
 
         def get_video_replacement(item_idx: int):
@@ -356,7 +342,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
                 num_video_tokens = videos.get_feature_size(item_idx)
             else:
                 image_size = videos.get_frame_size(item_idx)
-                num_video_tokens = self._get_num_video_tokens(
+                num_video_tokens = self.info.get_num_video_tokens(
                     image_width=image_size.width,
                     image_height=image_size.height,
                     num_frames=videos.get_num_frames(item_idx),
@@ -393,7 +379,10 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaOnevisionMultiModalProcessor,
+    info=LlavaOnevisionProcessingInfo,
+    dummy_inputs=LlavaOnevisionDummyInputsBuilder)
 class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
 
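
Taken together, the file ends up with a three-way split between processing
info, dummy-input building, and prompt processing. The sketch below restates
that shape in plain Python; the class roles, the self.info access pattern,
and the "frame tokens times frames plus one newline token" accounting come
from the diff, while every other name and the constructor wiring are
illustrative assumptions, not vLLM's actual implementation.

class ProcessingInfo:
    """Token accounting and HF lookups (cf. LlavaOnevisionProcessingInfo)."""

    def get_num_video_tokens(self, *, image_width: int, image_height: int,
                             num_frames: int) -> int:
        # Stand-in cost model; the real version derives tokens-per-frame
        # from the vision encoder's patch grid and pooling stride.
        num_frame_tokens = 196
        return num_frame_tokens * num_frames + 1  # +1 newline token


class DummyInputsBuilder:
    """Builds profiling inputs (cf. LlavaOnevisionDummyInputsBuilder)."""

    def __init__(self, info: ProcessingInfo) -> None:
        self.info = info  # helpers are reached via self.info, as in the diff

    def describe_dummy_video(self, num_frames: int) -> str:
        num_tokens = self.info.get_num_video_tokens(image_width=384,
                                                    image_height=384,
                                                    num_frames=num_frames)
        return f"dummy video: {num_frames} frames -> {num_tokens} tokens"


class MultiModalProcessor:
    """Prompt-replacement logic (cf. LlavaOnevisionMultiModalProcessor)."""

    def __init__(self, info: ProcessingInfo) -> None:
        self.info = info


print(DummyInputsBuilder(ProcessingInfo()).describe_dummy_video(8))
# dummy video: 8 frames -> 1569 tokens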