[VLM] Separate text-only and vision variants of the same model architecture (#13157)
This commit is contained in:
@@ -1,20 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/CogAgent
|
||||
"""Inference-only CogAgent model compatible with THUDM weights."""
|
||||
from argparse import Namespace
|
||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -31,204 +23,14 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, BatchFeature,
|
||||
MultiModalFieldConfig,
|
||||
PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
|
||||
class GLMImagePixelInputs(TypedDict):
|
||||
pixel_values: torch.Tensor
|
||||
"""Shape: `(batch_size, num_channels, height, width)`"""
|
||||
|
||||
|
||||
class GLM4VProcessor:
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if vision_config := getattr(config, "vision_config", None):
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
self.image_transform = transforms.Compose([
|
||||
transforms.Resize(
|
||||
(image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711),
|
||||
),
|
||||
])
|
||||
else:
|
||||
self.image_transform = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[TextInput, list[TextInput]]] = None,
|
||||
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
text_inputs = self.tokenizer(text)
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
if self.image_transform is None:
|
||||
raise ValueError("This model does not support image inputs")
|
||||
|
||||
pixel_values = [self.image_transform(image) for image in images]
|
||||
image_inputs = {"pixel_values": torch.stack(pixel_values)}
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
class GLM4VProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = self.ctx.tokenizer
|
||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||
return tokenizer
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(ChatGLMConfig)
|
||||
|
||||
def get_hf_processor(self) -> GLM4VProcessor:
|
||||
return GLM4VProcessor(
|
||||
self.get_hf_config(),
|
||||
self.get_tokenizer(),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_num_image_feature_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
if not (vision_config := getattr(hf_config, "vision_config", None)):
|
||||
return 0
|
||||
|
||||
image_size = vision_config["image_size"]
|
||||
patch_size = vision_config["patch_size"]
|
||||
grid_length = image_size // patch_size // 2
|
||||
return grid_length * grid_length
|
||||
|
||||
def get_num_image_feature_tokens(self) -> int:
|
||||
# EVA2CLIPModel has embeddings for boi and eoi tokens as well
|
||||
return self.get_num_image_tokens() + 2
|
||||
|
||||
|
||||
class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
hf_config = self.info.get_hf_config()
|
||||
if not (vision_config := getattr(hf_config, "vision_config", None)):
|
||||
return ProcessorInputs(prompt_text="", mm_data={})
|
||||
|
||||
target_width = target_height = vision_config["image_size"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
base_text = "<|begin_of_image|><|endoftext|><|end_of_image|>"
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text=base_text * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
if not hasattr(hf_config, "vision_config"):
|
||||
return []
|
||||
|
||||
boi_token_id = hf_config.boi_token_id
|
||||
image_token_id = hf_config.pad_token_id
|
||||
eoi_token_id = hf_config.eoi_token_id
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
num_image_tokens = self.info.get_num_image_tokens()
|
||||
image_tokens = [image_token_id] * num_image_tokens
|
||||
|
||||
return [boi_token_id] + image_tokens + [eoi_token_id]
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[boi_token_id, image_token_id, eoi_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class GLMAttention(nn.Module):
|
||||
@@ -489,7 +291,7 @@ class GLMTransformer(nn.Module):
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states = layer(
|
||||
@@ -498,8 +300,12 @@ class GLMTransformer(nn.Module):
|
||||
kv_cache=kv_caches[i - self.start_layer],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
|
||||
# Final layer norm.
|
||||
if get_pp_group().is_last_rank and self.post_layer_norm:
|
||||
if self.post_layer_norm:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@@ -534,61 +340,11 @@ class ChatGLMModel(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output_layer")
|
||||
|
||||
vision_config_flag = getattr(config, 'vision_config', None)
|
||||
if vision_config_flag is not None:
|
||||
self.vision_config = Namespace(**config.vision_config)
|
||||
self.vision = EVA2CLIPModel(self.config,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.vision")
|
||||
else:
|
||||
self.vision = None
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.encoder.make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> GLMImagePixelInputs:
|
||||
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
if pixel_values is not None and self.vision is not None:
|
||||
if isinstance(pixel_values, torch.Tensor):
|
||||
if pixel_values.ndim > 2:
|
||||
pixel_values = torch.concat(list(pixel_values))
|
||||
elif isinstance(pixel_values, list):
|
||||
return torch.concat(pixel_values)
|
||||
else:
|
||||
raise TypeError("""pixel_values must be a torch.Tensor
|
||||
or a list of torch.Tensor
|
||||
""")
|
||||
return GLMImagePixelInputs(pixel_values=pixel_values)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input["pixel_values"] is None:
|
||||
return None
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=self.config.torch_dtype)
|
||||
vision_embeddings = self.vision(pixel_values)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=[
|
||||
self.config.boi_token_id,
|
||||
self.config.pad_token_id,
|
||||
self.config.eoi_token_id,
|
||||
],
|
||||
)
|
||||
return inputs_embeds
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embedding(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -599,26 +355,24 @@ class ChatGLMModel(nn.Module):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = intermediate_tensors["hidden_states"]
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
# Run encoder.
|
||||
hidden_states = self.encoder(
|
||||
hidden_states=inputs_embeds,
|
||||
hidden_states=hidden_states,
|
||||
position_ids=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states})
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
@@ -660,12 +414,18 @@ class ChatGLMModel(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class ChatGLMBaseModel(nn.Module):
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={".word_embeddings": ""}, )
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
transformer_type: type[ChatGLMModel] = ChatGLMModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
@@ -678,27 +438,17 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.quant_config = quant_config
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
||||
8192)
|
||||
self.transformer = ChatGLMModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
self.transformer = transformer_type(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.transformer.output_layer.weight = (
|
||||
self.transformer.embedding.weight)
|
||||
self.lm_head = self.transformer.output_layer
|
||||
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
**kwargs)
|
||||
return hidden_states
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.transformer.make_empty_intermediate_tensors)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@@ -722,7 +472,7 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
class ChatGLM(ChatGLMBaseModel):
|
||||
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
"dense_h_to_4h": ["dense_h_to_4h"]
|
||||
@@ -738,82 +488,28 @@ class ChatGLM(ChatGLMBaseModel):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
if hasattr(config, "vision_config"):
|
||||
hf_overrides = {"architectures": ["GLM4VForCausalLM"]}
|
||||
raise RuntimeError(
|
||||
"The configuration of this model indicates that it supports "
|
||||
"vision inputs, but you instantiated the text-only version "
|
||||
"of this model. Please use the vision model by setting "
|
||||
f"`--hf-overrides {hf_overrides!r}`")
|
||||
|
||||
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
packed_modules_mapping = {
|
||||
"query_key_value": ["query_key_value"],
|
||||
"dense_h_to_4h": ["dense_h_to_4h"],
|
||||
"merged_proj": ["gate_proj", "dense_h_to_4h"]
|
||||
}
|
||||
# LoRA specific attributes
|
||||
supported_lora_modules = [
|
||||
"query_key_value",
|
||||
"dense",
|
||||
"dense_h_to_4h",
|
||||
"dense_4h_to_h",
|
||||
# vision
|
||||
"fc1",
|
||||
"fc2",
|
||||
"merged_proj",
|
||||
"linear_proj"
|
||||
]
|
||||
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="transformer.encoder",
|
||||
connector="transformer.vision.linear_proj",
|
||||
tower_model="transformer.vision.transformer")
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
return self.transformer.get_multimodal_embeddings(**kwargs)
|
||||
|
||||
def get_input_embeddings(
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.transformer.get_input_embeddings(input_ids,
|
||||
multimodal_embeddings)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
|
||||
info=GLM4VProcessingInfo,
|
||||
dummy_inputs=GLM4VDummyInputsBuilder)
|
||||
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
||||
SupportsMultiModal):
|
||||
# Ensure that the LoRA support check passes when the class is not
|
||||
# initialized, but set all these attributes to empty.
|
||||
# These will be updated when an instance class is selected
|
||||
packed_modules_mapping = {}
|
||||
supported_lora_modules = []
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __new__(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
# Initialize VL
|
||||
if hasattr(config, "vision_config"): # noqa: SIM108
|
||||
instance_cls = ChatGLMV
|
||||
# Initialize LLM
|
||||
else:
|
||||
instance_cls = ChatGLM
|
||||
|
||||
# quant_config references base class members,
|
||||
# so update values before init is called
|
||||
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
|
||||
cls.supported_lora_modules += instance_cls.supported_lora_modules
|
||||
cls.embedding_modules.update(instance_cls.embedding_modules)
|
||||
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
|
||||
return instance_cls(vllm_config=vllm_config, prefix=prefix)
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user