[V1] Override mm_counts for dummy data creation (#15703)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -108,7 +108,7 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
     ) -> Mapping[str, int]:
         return {
             "image": self.get_max_image_tokens(),
-            "video": self.get_max_video_tokens(seq_len),
+            "video": self.get_max_video_tokens(seq_len, mm_counts),
         }
 
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
@@ -202,10 +202,13 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
 
         return num_frames
 
-    def get_num_frames_with_most_features(self, seq_len: int) -> int:
-        mm_config = self.ctx.get_mm_config()
-        max_images = mm_config.get_limit_per_prompt("image")
-        max_videos = mm_config.get_limit_per_prompt("video")
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        max_images = mm_counts.get("image", 0)
+        max_videos = mm_counts.get("video", 0)
 
         max_image_tokens = self.get_max_image_tokens() * max_images
         max_total_frames = self._get_max_video_frames(seq_len -
@@ -215,13 +218,18 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
 
         return max(max_frames_per_video, 1)
 
-    def get_max_video_tokens(self, seq_len: int) -> int:
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
 
         return self.get_num_video_tokens(
             image_width=target_width,
             image_height=target_height,
-            num_frames=self.get_num_frames_with_most_features(seq_len),
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
         )
 
 
@@ -243,7 +251,8 @@ class LlavaOnevisionDummyInputsBuilder(
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
         target_num_frames = \
-            self.info.get_num_frames_with_most_features(seq_len)
+            self.info.get_num_frames_with_most_features(seq_len,
+                                                        mm_counts)
 
         mm_data = {
             "image":
|
||||
|
||||
Reference in New Issue
Block a user