[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date:   2025-01-08 18:59:58 +08:00
Committed by: GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions

vllm/model_executor/models/llava_onevision.py

@@ -17,19 +17,20 @@ from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
-from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
-                                   VideoEmbeddingItems, VideoProcessorItems)
-from vllm.multimodal.processing import MultiModalFieldConfig, PromptReplacement
-from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
+                                   VideoProcessorItems)
+from vllm.multimodal.processing import PromptReplacement
+from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

 from .clip import CLIPVisionModel
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import BaseLlavaProfilingInfo, init_vision_tower_for_llava
-from .llava_next import (LlavaNextLikeConfig, LlavaNextMultiModalProcessor,
-                         LlavaNextProcessingMixin)
+from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
+from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
+                         LlavaNextProcessingInfo)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -89,14 +90,23 @@ class LlavaOnevisionLikeConfig(LlavaNextLikeConfig, Protocol):
     video_token_index: Final[int]


-class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
+class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):

-    def _get_hf_config(self) -> LlavaOnevisionLikeConfig:
+    def get_hf_config(self) -> LlavaOnevisionLikeConfig:
         return self.ctx.get_hf_config(LlavaOnevisionConfig)

-    def _get_hf_processor(self):
+    def get_hf_processor(self):
         return self.ctx.get_hf_processor(LlavaOnevisionProcessor)

+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None, "video": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {
+            "image": self.get_max_image_tokens(),
+            "video": self.get_max_video_tokens(seq_len),
+        }
+
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
     # with additional logic afterwards taken from LlavaOnevisionProcessor
     def _get_num_unpadded_features(
@@ -141,16 +151,16 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
         image_width: int,
         image_height: int,
     ) -> int:
-        hf_config = self._get_hf_config()
+        hf_config = self.get_hf_config()
         spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)

-        vision_encoder_info = self._get_vision_encoder_info()
+        vision_encoder_info = self.get_vision_encoder_info()
         patch_grid_length = vision_encoder_info.get_patch_grid_length()
         pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)

         return pooled_grid_length * pooled_grid_length

-    def _get_num_video_tokens(
+    def get_num_video_tokens(
         self,
         *,
         image_width: int,
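For intuition, here is the per-frame arithmetic from the hunk above, worked
through under the assumption of a SigLIP-SO400M-style tower (384x384 input,
14x14 patches); the concrete numbers are illustrative, not guaranteed:

import math

patch_grid_length = 384 // 14  # 27 patches per side (assumed tower geometry)
spatial_pool_stride = 2        # the getattr() default shown above
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)  # 14
tokens_per_frame = pooled_grid_length * pooled_grid_length               # 196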
@@ -164,43 +174,14 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
         return num_frame_tokens * num_frames + 1  # Newline token


-class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
-                                  BaseLlavaProfilingInfo):
-
-    def _get_image_size_with_most_features(self) -> ImageSize:
-        hf_config = self._get_hf_config()
-        largest_feature_size, largest_feature_pinpoint = 0, None
-        for (height, width) in hf_config.image_grid_pinpoints:
-            feat_size = self._get_num_image_tokens(image_width=width,
-                                                   image_height=height)
-            if feat_size > largest_feature_size:
-                largest_feature_size = feat_size
-                largest_feature_pinpoint = ImageSize(width=width,
-                                                     height=height)
-
-        if largest_feature_size == 0 or largest_feature_pinpoint is None:
-            raise ValueError("Cannot have a largest feature size of 0!")
-
-        return largest_feature_pinpoint
-
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None, "video": None}
-
-    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
-        return {
-            "image": self._get_max_image_tokens(),
-            "video": self._get_max_video_tokens(seq_len),
-        }
-
     def _get_max_video_frames(self, max_tokens: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = self.get_image_size_with_most_features()

         num_frames = 0

         while True:
             next_num_frames = num_frames + 1
-            next_max_tokens = self._get_num_video_tokens(
+            next_max_tokens = self.get_num_video_tokens(
                 image_width=target_width,
                 image_height=target_height,
                 num_frames=next_num_frames,
@@ -213,12 +194,12 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         return num_frames

-    def _get_dummy_num_frames(self, seq_len: int) -> int:
+    def get_num_frames_with_most_features(self, seq_len: int) -> int:
         mm_config = self.ctx.get_mm_config()
         max_images = mm_config.limit_per_prompt.get("image", 1)
         max_videos = mm_config.limit_per_prompt.get("video", 1)

-        max_image_tokens = self._get_max_image_tokens() * max_images
+        max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
                                                       max_image_tokens)
         max_frames_per_video = min(max_total_frames // max(max_videos, 1),
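The while-loop above is a linear search for the largest frame count whose cost
still fits the token budget. A self-contained sketch of the same idea with a
simplified stand-in cost model (the real cost depends on the frame size):

def max_video_frames(max_tokens: int, tokens_per_frame: int = 196) -> int:
    # Mirrors _get_max_video_frames: grow the frame count until one more
    # frame would exceed the budget. Simplified cost: tokens_per_frame * n + 1
    # (the "+ 1" is the trailing newline token noted earlier).
    num_frames = 0
    while tokens_per_frame * (num_frames + 1) + 1 <= max_tokens:
        num_frames += 1
    return num_frames

print(max_video_frames(4096))  # 20: 196 * 20 + 1 = 3921 <= 4096 < 196 * 21 + 1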
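Concretely, the budget split above works like this (all numbers hypothetical):

# seq_len = 8192, limit_per_prompt = {"image": 2, "video": 1}
# Assume get_max_image_tokens() == 1500 (illustrative, not the real value).
# max_image_tokens   = 1500 * 2 = 3000
# video frame budget = 8192 - 3000 = 5192 tokens
# at 196 tokens/frame + 1 newline: _get_max_video_frames(5192) == 26
# max_frames_per_video = min(26 // 1, <the file's per-video frame cap>)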
@@ -226,15 +207,19 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         return max(max_frames_per_video, 1)

-    def _get_max_video_tokens(self, seq_len: int) -> int:
-        target_width, target_height = self._get_image_size_with_most_features()
+    def get_max_video_tokens(self, seq_len: int) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()

-        return self._get_num_video_tokens(
+        return self.get_num_video_tokens(
             image_width=target_width,
             image_height=target_height,
-            num_frames=self._get_dummy_num_frames(seq_len),
+            num_frames=self.get_num_frames_with_most_features(seq_len),
         )


+class LlavaOnevisionDummyInputsBuilder(
+        LlavaDummyInputsBuilder[LlavaOnevisionProcessingInfo]):
+
     def get_dummy_processor_inputs(
         self,
         seq_len: int,
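After this split, token accounting lives on the info class while dummy-input
construction lives on the builder, which reaches the info through `self.info`.
A rough sketch of the pattern for a new model (placeholder names, abbreviated
signatures; a sketch of the shape rather than a definitive implementation):

class MyProcessingInfo(BaseProcessingInfo):
    # HF config/processor access and max-token accounting go here.
    ...

class MyDummyInputsBuilder(BaseDummyInputsBuilder[MyProcessingInfo]):
    def get_dummy_processor_inputs(self, seq_len, mm_counts):
        # Size and frame targets are queried from the info object.
        width, height = self.info.get_image_size_with_most_features()
        ...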
@@ -243,10 +228,14 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         num_images = mm_counts.get("image", 0)
         num_videos = mm_counts.get("video", 0)

-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         image_token = processor.image_token
         video_token = processor.video_token
-        target_width, target_height = self._get_image_size_with_most_features()
+        target_width, target_height = \
+            self.info.get_image_size_with_most_features()
+        target_num_frames = \
+            self.info.get_num_frames_with_most_features(seq_len)

         mm_data = {
             "image":
@@ -257,7 +246,7 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
             self._get_dummy_videos(
                 width=target_width,
                 height=target_height,
-                num_frames=self._get_dummy_num_frames(seq_len),
+                num_frames=target_num_frames,
                 num_videos=num_videos,
             )
         }
@@ -268,11 +257,8 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
         )


-class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
-                                        LlavaNextMultiModalProcessor):
-
-    def _get_profiling_info(self) -> BaseProfilingInfo:
-        return LlavaOnevisionProfilingInfo(self.ctx)
+class LlavaOnevisionMultiModalProcessor(
+        BaseLlavaNextMultiModalProcessor[LlavaOnevisionProcessingInfo]):

     def _get_mm_fields_config(
         self,
@@ -303,7 +289,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
             mm_kwargs=mm_kwargs,
         )

-        processor = self._get_hf_processor()
+        processor = self.info.get_hf_processor()
         video_token = processor.video_token

         # LLaVA-OneVision processor doesn't support multiple videos
@@ -345,7 +331,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
             out_mm_kwargs=out_mm_kwargs,
         )

-        hf_config = self._get_hf_config()
+        hf_config = self.info.get_hf_config()
         video_token_id = hf_config.video_token_index

         def get_video_replacement(item_idx: int):
@@ -356,7 +342,7 @@ class LlavaOnevisionMultiModalProcessor(LlavaOnevisionProcessingMixin,
                 num_video_tokens = videos.get_feature_size(item_idx)
             else:
                 image_size = videos.get_frame_size(item_idx)
-                num_video_tokens = self._get_num_video_tokens(
+                num_video_tokens = self.info.get_num_video_tokens(
                     image_width=image_size.width,
                     image_height=image_size.height,
                     num_frames=videos.get_num_frames(item_idx),
@@ -393,7 +379,10 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
         return hidden_states


-@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
+@MULTIMODAL_REGISTRY.register_processor(
+    LlavaOnevisionMultiModalProcessor,
+    info=LlavaOnevisionProcessingInfo,
+    dummy_inputs=LlavaOnevisionDummyInputsBuilder)
 class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsPP):
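The registration now names all three collaborators, so the registry can
assemble the processor together with its info and dummy-inputs builder. The
general pattern a model now follows, sketched from the decorator above with
placeholder names:

@MULTIMODAL_REGISTRY.register_processor(
    MyMultiModalProcessor,              # prompt replacements + mm field config
    info=MyProcessingInfo,              # HF access + per-item token accounting
    dummy_inputs=MyDummyInputsBuilder,  # profiling inputs derived from the info
)
class MyModelForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    ...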