[Bugfix] Standardize getting number of image patches/tokens (#34358)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-13 12:47:01 +08:00
committed by GitHub
parent 6afa587d31
commit 372b2e762a
29 changed files with 319 additions and 331 deletions

View File

@@ -1869,12 +1869,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
*,
image_height: int,
image_width: int,
processor: Molmo2ProcessorWrapper | None = None,
processor: Molmo2ProcessorWrapper,
) -> int:
if processor is None:
processor = self.get_hf_processor()
hf_processor = processor.processor # type: ignore
hf_processor = processor.processor
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
# start/end tokens + image patch token + col tokens
@@ -1897,11 +1894,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
self,
*,
num_frames: int,
processor: Molmo2ProcessorWrapper | None = None,
processor: Molmo2ProcessorWrapper,
) -> int:
if processor is None:
processor = self.get_hf_processor()
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True)
# start/end tokens
extra = 2 + resize_nrows * (
@@ -1929,7 +1923,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
width = wr * crop_window_size + total_margin_pixels
feat_size = self.get_num_image_tokens(
image_height=height, image_width=width, processor=processor
image_height=height,
image_width=width,
processor=processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
@@ -1940,8 +1936,15 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint
def _get_max_video_frames(self, max_tokens: int) -> int:
num_tokens_per_frame = self.get_num_video_tokens(num_frames=1)
def _get_max_video_frames(
self,
max_tokens: int,
processor: Molmo2ProcessorWrapper,
) -> int:
num_tokens_per_frame = self.get_num_video_tokens(
num_frames=1,
processor=processor,
)
max_frames = max_tokens // num_tokens_per_frame
return max(max_frames, 1)
@@ -1950,10 +1953,11 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
video_processor = self.get_hf_processor().processor.video_processor
processor = self.get_hf_processor()
video_processor = processor.processor.video_processor
num_frames = video_processor.num_frames
max_videos = mm_counts.get("video", 0)
max_total_frames = self._get_max_video_frames(seq_len)
max_total_frames = self._get_max_video_frames(seq_len, processor)
max_frames_per_video = min(
max_total_frames // max(max_videos, 1),
num_frames,