[Bugfix] Clean up and fix multi-modal processors (#13012)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-10 18:45:21 +08:00
committed by GitHub
parent fde71262e0
commit 51f0b5f7f6
7 changed files with 124 additions and 154 deletions

View File

@@ -4,8 +4,8 @@
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from typing import (Iterable, List, Mapping, Optional, Sequence, Set, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import torch
from torch import nn
@@ -19,7 +19,6 @@ from transformers.tokenization_utils_base import TextInput
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -37,12 +36,10 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
BoundPromptReplacement,
MultiModalFieldConfig,
PlaceholderFeaturesInfo,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -53,39 +50,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
IMAGE_TOKEN_ID = 151329
def build_normalization_transform(image_size: int) -> transforms.Compose:
    """
    Assemble the preprocessing pipeline for images fed to the vision tower.

    The pipeline runs three steps in order: a bicubic resize to a square of
    ``image_size`` x ``image_size``, conversion to a tensor, and per-channel
    normalization with fixed mean / std constants.

    Args:
        image_size: side length of the square the image is resized to.

    Returns:
        Callable transform for normalizing and resizing one RGB image.
    """
    target_shape = (image_size, image_size)
    # Per-channel statistics, passed positionally as (mean, std).
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    resize_step = transforms.Resize(
        target_shape,
        interpolation=InterpolationMode.BICUBIC,
    )
    to_tensor_step = transforms.ToTensor()
    normalize_step = transforms.Normalize(channel_mean, channel_std)

    return transforms.Compose([resize_step, to_tensor_step, normalize_step])
def calculate_image_placeholder(vision_config):
    """
    Compute the placeholder count for a single image.

    Args:
        vision_config: mapping that provides ``"image_size"`` and
            ``"patch_size"`` entries.

    Returns:
        The square of ``image_size // patch_size // 2``.
    """
    patches_per_side = vision_config["image_size"] // vision_config["patch_size"]
    # The patch grid side is halved before squaring.
    grid_side = patches_per_side // 2
    return grid_side * grid_side
class GLMImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
@@ -109,9 +73,20 @@ class GLM4VProcessor:
self.config = config
self.tokenizer = tokenizer
if hasattr(self.config, "vision_config"):
self.image_transform = build_normalization_transform(
config.vision_config["image_size"])
if vision_config := getattr(config, "vision_config", None):
image_size = vision_config["image_size"]
self.image_transform = transforms.Compose([
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
])
else:
self.image_transform = None
@@ -150,9 +125,19 @@ class GLM4VProcessor:
class GLM4VProcessingInfo(BaseProcessingInfo):
def __init__(self, ctx):
    """
    Set up the processing info object.

    Args:
        ctx: context object forwarded unchanged to the base-class
            constructor.
    """
    super().__init__(ctx)
    # Derived image values are computed once, right after base setup.
    self._pre_calculate()
def get_tokenizer(self):
    """
    Return the tokenizer held by the processing context.

    Raises:
        AssertionError: if the context's tokenizer is not a
            ``PreTrainedTokenizer`` instance.
    """
    ctx_tokenizer = self.ctx.tokenizer
    assert isinstance(ctx_tokenizer, PreTrainedTokenizer)
    return ctx_tokenizer
def get_hf_config(self):
    """Return the result of asking the context for a ``ChatGLMConfig``."""
    config = self.ctx.get_hf_config(ChatGLMConfig)
    return config
def get_hf_processor(self) -> GLM4VProcessor:
    """
    Build a ``GLM4VProcessor`` for this model.

    Returns:
        A processor constructed from the HF config and the tokenizer
        exposed by this info object, in that argument order.
    """
    hf_config = self.get_hf_config()
    tokenizer = self.get_tokenizer()
    return GLM4VProcessor(hf_config, tokenizer)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    """
    Report the per-modality item limits.

    Returns:
        A mapping with a single ``"image"`` entry whose value is 1.
    """
    return dict(image=1)
@@ -162,27 +147,21 @@ class GLM4VProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.image_token_num + 2}
def _pre_calculate(self):
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
self.image_token_num = calculate_image_placeholder(vision_config)
self.image_size = vision_config["image_size"]
return {"image": self.get_num_image_feature_tokens()}
def get_num_image_tokens(self) -> int:
return self.image_token_num + 2
hf_config = self.get_hf_config()
if not (vision_config := getattr(hf_config, "vision_config", None)):
return 0
def get_image_size(self) -> ImageSize:
image_size = vision_config["image_size"]
patch_size = vision_config["patch_size"]
grid_length = image_size // patch_size // 2
return grid_length * grid_length
return ImageSize(height=self.image_size, width=self.image_size)
def get_hf_processor(self) -> GLM4VProcessor:
    """
    Create the ``GLM4VProcessor`` used for this model.

    Returns:
        A processor built from ``get_hf_config()`` and ``get_tokenizer()``,
        passed in that order.
    """
    config = self.get_hf_config()
    tok = self.get_tokenizer()
    return GLM4VProcessor(config, tok)
def get_num_image_feature_tokens(self) -> int:
    """
    Return the image token count plus the two boundary tokens.

    Returns:
        ``get_num_image_tokens()`` increased by 2.
    """
    # EVA2CLIPModel has embeddings for boi and eoi tokens as well
    boundary_tokens = 2
    return self.get_num_image_tokens() + boundary_tokens
class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
@@ -192,8 +171,12 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
if not (vision_config := getattr(hf_config, "vision_config", None)):
return ProcessorInputs(prompt_text="", mm_data={})
target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size()
mm_data = {
"image":
@@ -201,9 +184,11 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
height=target_height,
num_images=num_images)
}
text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
return ProcessorInputs(
prompt_text=text,
prompt_text=base_text * num_images,
mm_data=mm_data,
)
@@ -223,47 +208,28 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.info.get_hf_config()
if not hasattr(hf_config, "vision_config"):
return []
boi_token_id = hf_config.boi_token_id
image_token_id = hf_config.pad_token_id
eoi_token_id = hf_config.eoi_token_id
def get_replacement(item_idx: int):
image_tokens = self.info.image_token_num
return [IMAGE_TOKEN_ID] * image_tokens
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
return [boi_token_id] + image_tokens + [eoi_token_id]
return [
PromptReplacement(
modality="image",
target=[IMAGE_TOKEN_ID],
target=[boi_token_id, image_token_id, eoi_token_id],
replacement=get_replacement,
),
]
def _apply_prompt_replacements(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
mm_item_counts=mm_item_counts,
)
hf_config = self.info.get_hf_config()
boi_token_id = hf_config.boi_token_id
eoi_token_id = hf_config.eoi_token_id
placeholders = {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=p.start_idx - 1,
tokens=[boi_token_id] + p.tokens + [eoi_token_id],
) for p in ps
]
for modality, ps in placeholders.items()
}
return token_ids, text, placeholders
class GLMAttention(nn.Module):
@@ -618,7 +584,7 @@ class ChatGLMModel(nn.Module):
multimodal_embeddings=multimodal_embeddings,
placeholder_token_id=[
self.config.boi_token_id,
IMAGE_TOKEN_ID,
self.config.pad_token_id,
self.config.eoi_token_id,
],
)