[VLM] Reorganize profiling/processing-related code (#11812)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-08 18:59:58 +08:00
committed by GitHub
parent f12141170a
commit 2a0596bc48
23 changed files with 833 additions and 760 deletions

View File

@@ -17,12 +17,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, VideoEmbeddingItems,
VideoProcessorItems)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessingMixin,
PromptReplacement)
from vllm.multimodal.profiling import BaseProfilingInfo, ProcessorInputs
BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -47,33 +46,52 @@ class LlavaNextVideoPixelInputs(TypedDict):
"""
class LlavaNextVideoProcessingMixin(ProcessingMixin):
class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def _get_hf_config(self):
def get_hf_config(self):
return self.ctx.get_hf_config(LlavaNextVideoConfig)
def _get_vision_encoder_info(self):
return get_vision_encoder_info(self._get_hf_config())
def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config())
def _get_hf_processor(self):
def get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len),
)
return {"video": max_video_tokens}
def get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self.get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_num_frame_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
hf_config = self.get_hf_config()
spatial_pool_stride = hf_config.spatial_pool_stride
vision_encoder_info = self._get_vision_encoder_info()
vision_encoder_info = self.get_vision_encoder_info()
patch_grid_length = vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
def get_num_video_tokens(
self,
*,
image_width: int,
@@ -87,37 +105,14 @@ class LlavaNextVideoProcessingMixin(ProcessingMixin):
return num_frame_tokens * num_frames
class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
BaseProfilingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_size_with_most_features()
max_video_tokens = self._get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
)
return {"video": max_video_tokens}
def _get_image_size_with_most_features(self) -> ImageSize:
vision_encoder_info = self._get_vision_encoder_info()
width = height = vision_encoder_info.get_image_size()
return ImageSize(width=width, height=height)
def _get_max_video_frames(self, max_tokens: int) -> int:
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
while True:
next_num_frames = num_frames + 1
next_max_tokens = self._get_num_video_tokens(
next_max_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
@@ -130,7 +125,7 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return num_frames
def _get_dummy_num_frames(self, seq_len: int) -> int:
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_videos = mm_config.limit_per_prompt.get("video", 1)
@@ -138,6 +133,10 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
return max(max_total_frames // max(max_videos, 1), 1)
class LlavaNextVideoDummyInputsBuilder(
BaseDummyInputsBuilder[LlavaNextVideoProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
@@ -145,16 +144,20 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0)
processor = self._get_hf_processor()
processor = self.info.get_hf_processor()
video_token = processor.video_token
target_width, target_height = self._get_image_size_with_most_features()
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
mm_data = {
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_frames=target_num_frames,
num_videos=num_videos,
)
}
@@ -165,11 +168,8 @@ class LlavaNextVideoProfilingInfo(LlavaNextVideoProcessingMixin,
)
class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
BaseMultiModalProcessor):
def _get_profiling_info(self) -> BaseProfilingInfo:
return LlavaNextVideoProfilingInfo(self.ctx)
class LlavaNextVideoMultiModalProcessor(
BaseMultiModalProcessor[LlavaNextVideoProcessingInfo]):
def _get_mm_fields_config(
self,
@@ -184,7 +184,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index
def get_replacement(item_idx: int):
@@ -195,7 +195,7 @@ class LlavaNextVideoMultiModalProcessor(LlavaNextVideoProcessingMixin,
num_video_tokens = videos.get_feature_size(item_idx)
else:
image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
num_video_tokens = self.info.get_num_video_tokens(
image_width=image_size.width,
image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx),
@@ -269,7 +269,11 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_processor(
LlavaNextVideoMultiModalProcessor,
info=LlavaNextVideoProcessingInfo,
dummy_inputs=LlavaNextVideoDummyInputsBuilder,
)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):