[V1] Override mm_counts for dummy data creation (#15703)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-30 18:20:42 +08:00
committed by GitHub
parent 7fd8c0f85c
commit 803d5c35f3
9 changed files with 114 additions and 93 deletions

View File

@@ -108,7 +108,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
"video": self.get_max_video_tokens(seq_len, mm_counts),
}
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
@@ -202,10 +202,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
return num_frames
def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video")
def get_num_frames_with_most_features(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
max_images = mm_counts.get("image", 0)
max_videos = mm_counts.get("video", 0)
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
@@ -215,13 +218,18 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int:
def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len),
num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
)
@@ -243,7 +251,8 @@ class LlavaOnevisionDummyInputsBuilder(
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len)
self.info.get_num_frames_with_most_features(seq_len,
mm_counts)
mm_data = {
"image":