[Bugfix] Standardize getting number of image patches/tokens (#34358)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1869,12 +1869,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
*,
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
processor: Molmo2ProcessorWrapper | None = None,
|
||||
processor: Molmo2ProcessorWrapper,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
hf_processor = processor.processor # type: ignore
|
||||
hf_processor = processor.processor
|
||||
|
||||
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
|
||||
# start/end tokens + image patch token + col tokens
|
||||
@@ -1897,11 +1894,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
self,
|
||||
*,
|
||||
num_frames: int,
|
||||
processor: Molmo2ProcessorWrapper | None = None,
|
||||
processor: Molmo2ProcessorWrapper,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True)
|
||||
# start/end tokens
|
||||
extra = 2 + resize_nrows * (
|
||||
@@ -1929,7 +1923,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
width = wr * crop_window_size + total_margin_pixels
|
||||
|
||||
feat_size = self.get_num_image_tokens(
|
||||
image_height=height, image_width=width, processor=processor
|
||||
image_height=height,
|
||||
image_width=width,
|
||||
processor=processor,
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
@@ -1940,8 +1936,15 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
return largest_feature_pinpoint
|
||||
|
||||
def _get_max_video_frames(self, max_tokens: int) -> int:
|
||||
num_tokens_per_frame = self.get_num_video_tokens(num_frames=1)
|
||||
def _get_max_video_frames(
|
||||
self,
|
||||
max_tokens: int,
|
||||
processor: Molmo2ProcessorWrapper,
|
||||
) -> int:
|
||||
num_tokens_per_frame = self.get_num_video_tokens(
|
||||
num_frames=1,
|
||||
processor=processor,
|
||||
)
|
||||
max_frames = max_tokens // num_tokens_per_frame
|
||||
return max(max_frames, 1)
|
||||
|
||||
@@ -1950,10 +1953,11 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> int:
|
||||
video_processor = self.get_hf_processor().processor.video_processor
|
||||
processor = self.get_hf_processor()
|
||||
video_processor = processor.processor.video_processor
|
||||
num_frames = video_processor.num_frames
|
||||
max_videos = mm_counts.get("video", 0)
|
||||
max_total_frames = self._get_max_video_frames(seq_len)
|
||||
max_total_frames = self._get_max_video_frames(seq_len, processor)
|
||||
max_frames_per_video = min(
|
||||
max_total_frames // max(max_videos, 1),
|
||||
num_frames,
|
||||
|
||||
Reference in New Issue
Block a user