[Bugfix] Comprehensively test and fix LLaVA-NeXT feature size calculation (#11800)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-07 18:25:02 +08:00
committed by GitHub
parent 8082ad7950
commit 8f37be38eb
6 changed files with 257 additions and 93 deletions

View File

@@ -3,7 +3,6 @@ from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from transformers import (BatchFeature, LlavaOnevisionConfig,
@@ -98,6 +97,8 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
def _get_hf_processor(self):
    """Look up the HF processor for this model through the context.

    Delegates to ``self.ctx.get_hf_processor`` with the
    ``LlavaOnevisionProcessor`` class and returns its result unchanged.
    """
    processor_cls = LlavaOnevisionProcessor
    return self.ctx.get_hf_processor(processor_cls)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
# with additional logic afterwards taken from LlavaOnevisionProcessor
def _get_num_unpadded_features(
self,
*,
@@ -107,35 +108,28 @@ class LlavaOnevisionProcessingMixin(LlavaNextProcessingMixin):
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
# NOTE: Use float32 to remain consistent with HF output
current_height_f = np.float32(npatches * num_patch_height)
current_width_f = np.float32(npatches * num_patch_width)
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_width_f = np.float32(original_width)
original_height_f = np.float32(original_height)
aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
original_aspect_ratio = original_width_f / original_height_f
current_aspect_ratio = current_width_f / current_height_f
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width_f / original_width_f
new_height = int(original_height_f * scale_factor)
padding = (current_height_f - new_height) // 2
current_height_f -= 2 * padding
if aspect_ratio > current_aspect_ratio:
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
scale_factor = current_height_f / original_height_f
new_width = int(original_width_f * scale_factor)
padding = (current_width_f - new_width) // 2
current_width_f -= 2 * padding
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = int(current_height_f * current_width_f)
newline_features = int(current_height_f)
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height_f * current_width_f /
(9 * npatches**2))
ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
if ratio > 1.1:
height_factor = int(current_height_f // ratio)
width_factor = int(current_width_f // ratio)
height_factor = int(current_height // ratio)
width_factor = int(current_width // ratio)
unpadded_features = height_factor * width_factor
newline_features = height_factor