[VLM] Various cleanup and fixes (#14806)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-14 20:58:19 +08:00
committed by GitHub
parent 40253bab44
commit ab93f1360f
14 changed files with 283 additions and 273 deletions

View File

@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict
from typing import Literal, Optional, Set, Tuple, TypedDict
import torch
import torch.nn as nn
@@ -31,8 +31,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -58,10 +57,12 @@ class FuyuImagePatchInputs(TypedDict):
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
patches_per_image: List[int]
patches_per_image: list[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `flat_data`.
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
"""
@@ -317,7 +318,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return None
def _process_image_input(
self, image_input: FuyuImagePatchInputs) -> NestedTensors:
self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]