Files
vllm/vllm/model_executor/models/llava.py

374 lines
15 KiB
Python
Raw Normal View History

2024-05-20 16:11:25 +08:00
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
import torch.nn as nn
from transformers import CLIPVisionConfig, LlavaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.sequence import SamplerOutput
from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
dummy_seq_data_for_clip)
from .interfaces import SupportsVision
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
def __init__(self, vision_hidden_size: int, text_hidden_size: int,
projector_hidden_act: str):
super().__init__()
self.linear_1 = nn.Linear(vision_hidden_size,
text_hidden_size,
bias=True)
self.act = get_act_fn(projector_hidden_act)
self.linear_2 = nn.Linear(text_hidden_size,
text_hidden_size,
bias=True)
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
if mask.sum() != image_feature_size:
raise ValueError(f"image_feature_size should be {image_feature_size}, "
f"but found: {mask.sum()}")
inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
vision_embeddings.shape[-1])
return inputs_embeds
2024-05-20 16:11:25 +08:00
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class LlavaImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, image_feature_size, hidden_size)"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
vision_config,
seq_len,
image_token_id=hf_config.image_token_index,
)
image_input_type = multimodal_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
mm_data: MultiModalData
if image_input_type == ImageInputType.PIXEL_VALUES:
mm_data = dummy_pixel_data_for_clip(vision_config)
elif image_input_type == ImageInputType.IMAGE_FEATURES:
mm_data = dummy_feature_data_for_clip(vision_config)
return seq_data, mm_data
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
config: LlavaConfig,
vlm_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.vlm_config = vlm_config
if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
else:
self.vision_tower = None
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
2024-05-20 16:11:25 +08:00
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
2024-05-20 16:11:25 +08:00
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vlm_config.image_input_shape[1:]}. "
2024-05-20 16:11:25 +08:00
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
return data
def _parse_and_validate_image_input(
2024-06-03 13:56:41 +08:00
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vlm_config.image_input_type
2024-05-20 16:11:25 +08:00
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES:
2024-06-03 13:56:41 +08:00
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
2024-05-20 16:11:25 +08:00
return LlavaImagePixelInputs(
type="pixel_values",
2024-06-03 13:56:41 +08:00
data=self._validate_image_data(pixel_values),
2024-05-20 16:11:25 +08:00
)
2024-06-03 13:56:41 +08:00
if expected_input_type == ImageInputType.IMAGE_FEATURES:
if pixel_values is not None:
raise ValueError(
"Expected image features but got pixel values")
if image_features is None:
return None
if not isinstance(image_features, torch.Tensor):
raise ValueError("Incorrect type of image features. "
f"Got type: {type(image_features)}")
2024-05-20 16:11:25 +08:00
return LlavaImageFeatureInputs(
type="image_features",
2024-06-03 13:56:41 +08:00
data=self._validate_image_data(image_features),
2024-05-20 16:11:25 +08:00
)
return None
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)
2024-05-20 16:11:25 +08:00
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
def _process_image_pixels(self,
inputs: LlavaImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]
return self.multi_modal_projector(image_features)
2024-06-03 13:56:41 +08:00
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
"<image>\nUSER: What's the content of the image?\nASSISTANT:".
Tokenizer outputs:
[1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
The to-be-inserted image has a size of 576 (24 * 24) along the context
length dimension.
`input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
9047, 13566, 29901].
There will be 576 `32000` in the `input_ids`.
(32000 is the token id for `<image>`.)
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
This model has two modes of image inputs:
`PIXEL_VALUES` and `IMAGE_FEATURES`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
Expects a batch with shape `[1, 3, 336, 336]`.
(Only applicable to `PIXEL_VALUES` mode)
image_features: The image features for each input image outputted by
the vision tower before passing to the multi-modal projector.
Expects a batch with shape `[1, 576, 1024]`.
(Only applicable to `IMAGE_FEATURES` mode)
See also:
Each input maps to huggingface implementation, as follows:
- `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
- `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
"""
2024-06-03 13:56:41 +08:00
image_input = self._parse_and_validate_image_input(**kwargs)
2024-05-20 16:11:25 +08:00
2024-06-03 13:56:41 +08:00
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
2024-05-20 16:11:25 +08:00
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vlm_config.image_token_id)
2024-05-20 16:11:25 +08:00
input_ids = None
else:
inputs_embeds = None
2024-05-20 16:11:25 +08:00
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)