Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-10-05 15:06:22 +01:00; committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
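The hunks below are representative of the formatting-only change applied across the tree: yapf's hanging indents aligned to the opening parenthesis, and isort's aligned parenthesized imports, are rewritten into ruff's Black-compatible style (one item per line with a magic trailing comma), and single-quoted string literals become double-quoted. Import sorting previously handled by isort is presumably covered by ruff's isort-compatible `I` lint rules. A minimal before/after excerpt from the diff (this file appears to be the InternVL model implementation):

    # Before: yapf + isort, continuation lines aligned to the opening parenthesis.
    from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                        MultiModalKwargsItems)

    # After: ruff format (Black-compatible), one item per line; the trailing
    # comma keeps the block exploded on future reformats.
    from vllm.multimodal.inputs import (
        MultiModalDataDict,
        MultiModalFieldConfig,
        MultiModalKwargsItems,
    )

    # String literals switch from single to double quotes (Black's default).
    IMG_START = "<img>"  # was: IMG_START = '<img>'

In the hunks shown here the edits are formatter output only; no logic changes.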


@@ -23,31 +23,48 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.models.intern_vit import (
InternVisionModel,
InternVisionPatchModel,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import set_default_torch_num_threads
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'
IMG_START = "<img>"
IMG_END = "</img>"
IMG_CONTEXT = "<IMG_CONTEXT>"
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
@@ -62,6 +79,7 @@ class InternVLImagePixelInputs(TensorSchema):
- h: Height of each image patch
- w: Width of each image patch
"""
type: Literal["pixel_values"]
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
@@ -74,13 +92,12 @@ class InternVLImageEmbeddingInputs(TensorSchema):
- f: Total image feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
type: Literal["image_embeds"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("n", "f", "h")]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")]
InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
InternVLImageInputs = Union[InternVLImagePixelInputs, InternVLImageEmbeddingInputs]
class InternVLVideoPixelInputs(TensorSchema):
@@ -92,6 +109,7 @@ class InternVLVideoPixelInputs(TensorSchema):
- h: Height of each video frame
- w: Width of each video frame
"""
type: Literal["pixel_values_videos"]
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
@@ -104,25 +122,27 @@ class InternVLVideoEmbeddingInputs(TensorSchema):
- f: Total video feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
type: Literal["video_embeds"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("n", "f", "h")]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]], TensorShape("n", "f", "h")]
InternVLVideoInputs = Union[InternVLVideoPixelInputs,
InternVLVideoEmbeddingInputs]
InternVLVideoInputs = Union[InternVLVideoPixelInputs, InternVLVideoEmbeddingInputs]
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
transform = T.Compose(
[
T.Lambda(lambda img: convert_image_mode(img, "RGB")),
T.Resize(
(input_size, input_size), interpolation=T.InterpolationMode.BICUBIC
),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD),
]
)
# Image transformation operations (which include tensor computations
# on the CPU) can occupy a substantial number of CPU cores, introducing
# overhead due to CPU contention. This issue becomes particularly
@@ -147,7 +167,7 @@ def find_closest_aspect_ratio(
height: int,
image_size: int,
) -> tuple[int, int]:
best_ratio_diff = float('inf')
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
@@ -182,10 +202,13 @@ def get_internvl_target_ratios(
min_num: int,
max_num: int,
) -> list[tuple[int, int]]:
target_ratios = {(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1) if min_num <= i * j <= max_num}
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if min_num <= i * j <= max_num
}
return sorted(target_ratios, key=lambda x: x[0] * x[1])
@@ -243,10 +266,12 @@ def dynamic_preprocess_internvl(
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = ((i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size)
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
@@ -349,7 +374,8 @@ class BaseInternVLProcessor(ABC):
assert isinstance(dynamic_image_size, bool)
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
)
self.image_size = image_size
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
@@ -377,14 +403,18 @@ class BaseInternVLProcessor(ABC):
dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None,
) -> tuple[int, int]:
min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch
is None else min_dynamic_patch)
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
is None else max_dynamic_patch)
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
is None else dynamic_image_size)
use_thumbnail = (self.use_thumbnail
if use_thumbnail is None else use_thumbnail)
min_dynamic_patch = (
self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch
)
max_dynamic_patch = (
self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch
)
dynamic_image_size = (
self.dynamic_image_size
if dynamic_image_size is None
else dynamic_image_size
)
use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail
return resolve_internvl_min_max_num(
min_dynamic_patch=min_dynamic_patch,
@@ -451,7 +481,8 @@ class BaseInternVLProcessor(ABC):
min_num=min_num,
max_num=max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
)
for image in images
]
def _preprocess_image(
@@ -472,10 +503,10 @@ class BaseInternVLProcessor(ABC):
dynamic_image_size=dynamic_image_size,
)
image_inputs = {
"pixel_values_flat":
torch.cat(pixel_values_lst),
"image_num_patches":
torch.tensor([len(item) for item in pixel_values_lst]),
"pixel_values_flat": torch.cat(pixel_values_lst),
"image_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst]
),
}
for pixel_values in pixel_values_lst:
@@ -483,11 +514,10 @@ class BaseInternVLProcessor(ABC):
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
text = [t.replace("<image>", image_repl.full, 1) for t in text]
return text, image_inputs
def _make_batch_input(self,
input_item: Optional[Union[Any, list[Any]]] = None):
def _make_batch_input(self, input_item: Optional[Union[Any, list[Any]]] = None):
if input_item is None:
input_item = []
if not isinstance(input_item, list):
@@ -581,7 +611,8 @@ class InternVLProcessor(BaseInternVLProcessor):
min_num=min_num,
max_num=max_num,
use_thumbnail=False,
) for video in videos
)
for video in videos
]
def _preprocess_video(
@@ -598,18 +629,19 @@ class InternVLProcessor(BaseInternVLProcessor):
dynamic_image_size=dynamic_image_size,
)
video_inputs = {
"pixel_values_flat_video":
torch.cat(pixel_values_lst_video),
"video_num_patches":
torch.tensor([len(item) for item in pixel_values_lst_video]),
"pixel_values_flat_video": torch.cat(pixel_values_lst_video),
"video_num_patches": torch.tensor(
[len(item) for item in pixel_values_lst_video]
),
}
for pixel_values in pixel_values_lst_video:
num_patches = pixel_values.shape[0]
video_repl = self.get_video_repl(self.num_image_token,
num_patches, self.video_token)
text = [t.replace('<video>', video_repl.full, 1) for t in text]
video_repl = self.get_video_repl(
self.num_image_token, num_patches, self.video_token
)
text = [t.replace("<video>", video_repl.full, 1) for t in text]
return text, video_inputs
def __call__(
@@ -665,9 +697,9 @@ class InternVLProcessor(BaseInternVLProcessor):
repl_features = video_context_token * self.num_image_token
repl_features_with_sep = IMG_START + repl_features + IMG_END
# num_patches is equal to num_frames
repl_full = ''.join([
f'Frame{i+1}: {repl_features_with_sep}' for i in range(num_patches)
])
repl_full = "".join(
[f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
)
return PromptUpdateDetails.select_text(repl_full, video_context_token)
@@ -714,8 +746,7 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
largest_feature_pinpoint = ImageSize(width=width, height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
@@ -750,23 +781,23 @@ class BaseInternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides)
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
)
}
class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
""" Basic image-only MultiModalProcessor for InternVL-style models."""
"""Basic image-only MultiModalProcessor for InternVL-style models."""
def _call_hf_processor(
self,
@@ -802,7 +833,8 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
"image", image_num_patches
),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
@@ -830,7 +862,8 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
def get_replacement_internvl(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
@@ -889,8 +922,7 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
processor = self.get_hf_processor()
max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = (seq_len -
max_image_tokens) // processor.num_image_token
max_total_frames = (seq_len - max_image_tokens) // processor.num_image_token
max_frames_per_video = max_total_frames // max(max_videos, 1)
return max(max_frames_per_video, 1)
@@ -906,7 +938,8 @@ class InternVLProcessingInfo(BaseInternVLProcessingInfo):
class InternVLDummyInputsBuilder(
BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]):
BaseInternVLDummyInputsBuilder[InternVLProcessingInfo]
):
"""InternVL DummyInputsBuilder extended for video support"""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
@@ -920,23 +953,25 @@ class InternVLDummyInputsBuilder(
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict:
dummy_image = super().get_dummy_mm_data(seq_len=seq_len,
mm_counts=mm_counts,
mm_options=mm_options)
dummy_image = super().get_dummy_mm_data(
seq_len=seq_len, mm_counts=mm_counts, mm_options=mm_options
)
if self.info.supports_video:
config = self.info.get_hf_config()
image_size: int = config.vision_config.image_size
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts
)
num_videos = mm_counts.get("video", 0)
video_overrides = mm_options.get("video") if mm_options else None
dummy_video = {
"video":
self._get_dummy_videos(width=image_size,
height=image_size,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides)
"video": self._get_dummy_videos(
width=image_size,
height=image_size,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
)
}
else:
dummy_video = {}
@@ -944,7 +979,8 @@ class InternVLDummyInputsBuilder(
class InternVLMultiModalProcessor(
BaseInternVLMultiModalProcessor[InternVLProcessingInfo]):
BaseInternVLMultiModalProcessor[InternVLProcessingInfo]
):
"""InternVL MultiModalProcessor extended for video support"""
def _call_hf_processor(
@@ -954,12 +990,15 @@ class InternVLMultiModalProcessor(
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs, tok_kwargs)
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs, tok_kwargs
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
if self.info.supports_video and (
video_token_id := hf_processor.video_token_id) is not None:
if (
self.info.supports_video
and (video_token_id := hf_processor.video_token_id) is not None
):
processed_outputs["video_token_id"] = torch.tensor(video_token_id)
return processed_outputs
@@ -968,18 +1007,16 @@ class InternVLMultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_fields = super()._get_mm_fields_config(hf_inputs,
hf_processor_mm_kwargs)
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
if self.info.supports_video:
video_num_patches = hf_inputs.get("video_num_patches",
torch.empty(0))
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
num_videos = len(video_num_patches)
video_fields = dict(
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches),
"video", video_num_patches
),
video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared(
"video", num_videos),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
)
else:
video_fields = {}
@@ -1015,9 +1052,8 @@ class InternVLMultiModalProcessor(
assert isinstance(num_patches, int)
return hf_processor.get_video_repl(
feature_size,
num_patches,
video_context_token=hf_processor.video_token)
feature_size, num_patches, video_context_token=hf_processor.video_token
)
if self.info.supports_video:
prompt_repl = [
@@ -1026,7 +1062,7 @@ class InternVLMultiModalProcessor(
modality="video",
target="<video>",
replacement=get_video_replacement_internvl,
)
),
]
return prompt_repl
@@ -1035,9 +1071,9 @@ class InternVLMultiModalProcessor(
@MULTIMODAL_REGISTRY.register_processor(
InternVLMultiModalProcessor,
info=InternVLProcessingInfo,
dummy_inputs=InternVLDummyInputsBuilder)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
dummy_inputs=InternVLDummyInputsBuilder,
)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
merge_by_field_config = True
supports_encoder_tp_data = True
@@ -1067,12 +1103,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
)
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
self.is_mono = self.llm_arch_name == "InternLM2VEForCausalLM"
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
@@ -1093,18 +1130,20 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self.visual_token_mask = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.language_model.make_empty_intermediate_tensors
)
def _patch_quant_config(self, config: PretrainedConfig,
quant_config: QuantizationConfig):
def _patch_quant_config(
self, config: PretrainedConfig, quant_config: QuantizationConfig
):
# the awq models from OpenGVLab missing `modules_to_not_convert`
# patch the quant_config to add `modules_to_not_convert` back
if isinstance(quant_config, AWQConfig):
text_config = config.text_config
llm_quant_config = getattr(text_config, "quantization_config",
None)
if (not quant_config.modules_to_not_convert) and \
(llm_quant_config is not None):
llm_quant_config = getattr(text_config, "quantization_config", None)
if (not quant_config.modules_to_not_convert) and (
llm_quant_config is not None
):
quant_config.modules_to_not_convert.append("vision_model")
def _init_vision_model(
@@ -1118,8 +1157,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
if not is_mono:
vision_feature_layer = config.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
num_hidden_layers = (
config.vision_config.num_hidden_layers + vision_feature_layer + 1
)
else:
num_hidden_layers = vision_feature_layer + 1
@@ -1128,7 +1168,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
use_data_parallel=self.use_data_parallel)
use_data_parallel=self.use_data_parallel,
)
else:
return InternVisionPatchModel(config.vision_config)
@@ -1137,9 +1178,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
llm_hidden_size = config.text_config.hidden_size
return nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
llm_hidden_size),
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
nn.Linear(
vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
),
nn.GELU(),
nn.Linear(llm_hidden_size, llm_hidden_size),
)
@@ -1150,9 +1192,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
if self.ps_version == 'v1':
x = x.view(
n,
int(h * scale_factor),
int(w * scale_factor),
int(c / (scale_factor * scale_factor)),
)
if self.ps_version == "v1":
pass
else:
x = x.permute(0, 2, 1, 3).contiguous()
@@ -1162,17 +1208,16 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
vit_embeds = self.vision_model(pixel_values=pixel_values)
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1]**0.5)
h = w = int(vit_embeds.shape[1] ** 0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds,
scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
vit_embeds.shape[-1])
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImageInputs]:
self, **kwargs: object
) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
@@ -1204,7 +1249,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[InternVLVideoPixelInputs]:
self, **kwargs: object
) -> Optional[InternVLVideoPixelInputs]:
pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("image_embeds", None)
@@ -1239,8 +1285,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
self,
image_input: Union[InternVLImageInputs, InternVLVideoInputs],
) -> tuple[torch.Tensor, ...]:
if (image_input["type"] == "image_embeds"
or image_input["type"] == "video_embeds"):
if (
image_input["type"] == "image_embeds"
or image_input["type"] == "video_embeds"
):
return image_input["data"]
assert self.vision_model is not None
@@ -1251,14 +1299,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# Only one image in the current batch
if len(num_patches) == 1:
return (image_embeds.view(-1,
self.config.text_config.hidden_size), )
return (image_embeds.view(-1, self.config.text_config.hidden_size),)
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_embeds = image_embeds.view(-1, self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in num_patches
]
@@ -1270,31 +1316,29 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values_flat",
"image_embeds") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("pixel_values_flat_video",
) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
if (
input_key in ("pixel_values_flat", "image_embeds")
and "images" not in modalities
):
modalities["images"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_flat_video",) and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(**kwargs)
return modalities
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
if self.is_mono:
assert self.img_context_token_id is not None
self.visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
self.visual_token_mask = (input_ids == self.img_context_token_id).reshape(
-1, 1
)
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
@@ -1325,8 +1369,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
is_multimodal: Optional[torch.Tensor] = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(
multimodal_embeddings) > 0:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
self._set_visual_token_mask(input_ids)
# This is to satisfy the type checker for each overload
@@ -1348,7 +1391,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
@@ -1362,8 +1404,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
# Only required if the model is mono-architecture
if self.visual_token_mask is not None:
forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask})
forward_kwargs.update({"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
hidden_states = self.language_model.model(**forward_kwargs)
@@ -1375,14 +1416,21 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# unused modules appear in OpenGVLab/InternVideo2_5_Chat_8B
skip_prefixes = [
"action_embed", "temporal_embed", "track_embed",
"track_embed_decoder", "box_token", "cg_criterion", "cg_model",
"loc_encoder", "loc_decoder", "sam", "temporal_token",
"track_token"
"action_embed",
"temporal_embed",
"track_embed",
"track_embed_decoder",
"box_token",
"cg_criterion",
"cg_model",
"loc_encoder",
"loc_decoder",
"sam",
"temporal_token",
"track_token",
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
@@ -1394,4 +1442,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="mlp1",
tower_model="vision_model")
tower_model="vision_model",
)