[Bugfix] Comprehensively test and fix LLaVA-NeXT feature size calculation (#11800)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-07 18:25:02 +08:00
committed by GitHub
parent 8082ad7950
commit 8f37be38eb
6 changed files with 257 additions and 93 deletions

View File

@@ -2,7 +2,6 @@ from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
@@ -74,7 +73,7 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
def _get_hf_processor(self):
return self.ctx.get_hf_processor(LlavaNextProcessor)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
def _get_num_image_tokens(
self,
*,
@@ -111,7 +110,7 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
return unpadded_feature_size + newline_feature_size + base_feature_size
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
# Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
def _get_num_unpadded_features(
self,
*,
@@ -121,29 +120,23 @@ class LlavaNextProcessingMixin(BaseLlavaProcessingMixin):
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
# NOTE: Use float32 to remain consistent with HF output
current_height_f = np.float32(npatches * num_patch_height)
current_width_f = np.float32(npatches * num_patch_width)
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_width_f = np.float32(original_width)
original_height_f = np.float32(original_height)
aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
original_aspect_ratio = original_width_f / original_height_f
current_aspect_ratio = current_width_f / current_height_f
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width_f / original_width_f
new_height = int(original_height_f * scale_factor)
padding = (current_height_f - new_height) // 2
current_height_f -= 2 * padding
if aspect_ratio > current_aspect_ratio:
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
scale_factor = current_height_f / original_height_f
new_width = int(original_width_f * scale_factor)
padding = (current_width_f - new_width) // 2
current_width_f -= 2 * padding
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = int(current_height_f * current_width_f)
newline_features = int(current_height_f)
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)