[vlm] Remove vision language config. (#6089)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
xwjiang2010 authored on 2024-07-03 15:14:16 -07:00, committed by GitHub
parent 3c6325f0fc
commit d9e98f42e4
43 changed files with 371 additions and 465 deletions


@@ -3,7 +3,7 @@ from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
 from typing_extensions import TypeGuard
-from vllm.config import LoRAConfig, VisionLanguageConfig
+from vllm.config import LoRAConfig, MultiModalConfig
 from vllm.logger import init_logger
 logger = init_logger(__name__)
@@ -22,7 +22,7 @@ class SupportsVision(Protocol):
         MRO of your model class.
     """
-    def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
+    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
         ...
@@ -32,7 +32,7 @@ class SupportsVision(Protocol):
 class _SupportsVisionType(Protocol):
     supports_vision: Literal[True]
-    def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
+    def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
         ...
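Aside (illustration, not part of the diff): a minimal sketch of a model class satisfying the updated SupportsVision protocol, which now takes multimodal_config instead of vlm_config. ToyVisionModel and its body are hypothetical; only the multimodal_config keyword mirrors the interface above.

import torch.nn as nn

from vllm.config import MultiModalConfig


class ToyVisionModel(nn.Module):
    """Hypothetical model showing the refactored constructor contract."""

    # Class-level flag required by the SupportsVision protocol.
    supports_vision = True

    def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
        super().__init__()
        # Engine-level multimodal options; per-model vision settings
        # (image size, image token id, ...) now come from the HF config.
        self.multimodal_config = multimodal_config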


@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers import CLIPVisionConfig, LlavaConfig
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -108,13 +108,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
     def __init__(self,
                  config: LlavaConfig,
-                 vlm_config: VisionLanguageConfig,
+                 multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
-        self.vlm_config = vlm_config
+        self.multimodal_config = multimodal_config
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = CLIPVisionModel(config.vision_config)
@@ -138,14 +138,13 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
         self.sampler = Sampler()
     def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
+        if list(data.shape)[1:] != [
+                3, self.config.vision_config.image_size,
+                self.config.vision_config.image_size
+        ]:
             raise ValueError(
-                f"The expected image tensor shape is batch dimension plus "
-                f"{self.vlm_config.image_input_shape[1:]}. "
-                f"You supplied {data.shape}. "
-                f"If you are using vLLM's entrypoint, make sure your "
-                f"supplied image input is consistent with "
-                f"image_input_shape in engine args.")
+                "The expected image tensor shape is batch dimension plus "
+                "channel, height and width.")
         return data
@@ -244,7 +243,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vlm_config.image_token_id)
+                self.config.image_token_index)
             input_ids = None
         else:
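Aside (illustration, not part of the diff): with VisionLanguageConfig removed, the shape check above no longer compares against an engine-side image_input_shape; it derives the expected trailing (3, image_size, image_size) dimensions from the HF vision config. A standalone sketch of the same check, assuming the 336x336 CLIP input size used by LLaVA-1.5:

import torch

IMAGE_SIZE = 336  # assumption; the model reads config.vision_config.image_size


def validate_image_batch(data: torch.Tensor) -> torch.Tensor:
    # Expect a batch of RGB images shaped (batch, 3, height, width).
    if list(data.shape)[1:] != [3, IMAGE_SIZE, IMAGE_SIZE]:
        raise ValueError(
            "The expected image tensor shape is batch dimension plus "
            "channel, height and width.")
    return data


validate_image_batch(torch.zeros(2, 3, 336, 336))    # passes
# validate_image_batch(torch.zeros(2, 3, 224, 224))  # raises ValueError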


@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@ from transformers.models.llava_next.modeling_llava_next import (
 from typing_extensions import NotRequired
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -204,13 +204,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
     def __init__(self,
                  config: LlavaNextConfig,
-                 vlm_config: VisionLanguageConfig,
+                 multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
-        self.vlm_config = vlm_config
+        self.multimodal_config = multimodal_config
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = CLIPVisionModel(config=config.vision_config)
@@ -244,6 +244,47 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         return data
+    def _validate_pixel_values(
+        self, data: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        def _validate_shape(data: torch.Tensor):
+            dim = data.dim()
+            height = width = self.config.vision_config.image_size
+            # All 4d image tensors have the same number of patches,
+            # so data is a 5d batch of these tensors
+            if dim == 5:
+                if list(data.shape)[2:] != [
+                        3, self.config.vision_config.image_size,
+                        self.config.vision_config.image_size
+                ]:
+                    raise ValueError(
+                        "Expected pixel value tensor in shape of: (batch size, "
+                        f"patch number, 3, {height}, {width}), got {data.shape}"
+                    )
+            # 4d image tensors have different number of patches,
+            # so data is each individual tensor.
+            elif dim == 4:
+                if list(data.shape)[1:] != [
+                        3, self.config.vision_config.image_size,
+                        self.config.vision_config.image_size
+                ]:
+                    raise ValueError(
+                        "Expected pixel value tensor in shape of: (patch "
+                        f"number, 3, {height}, {width}), got {data.shape}")
+            else:
+                raise ValueError(
+                    f"Invalid pixel value tensor of shape {data.shape}")
+        if isinstance(data, torch.Tensor):
+            _validate_shape(data)
+        else:
+            [_validate_shape(d) for d in data]
+        return data
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -262,7 +303,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         return LlavaNextImagePixelInputs(
             type="pixel_values",
-            data=pixel_values,
+            data=self._validate_pixel_values(pixel_values),
             image_sizes=self._validate_image_sizes(image_sizes),
         )
@@ -454,7 +495,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vlm_config.image_token_id)
+                self.config.image_token_index)
             input_ids = None
         else:
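Aside (illustration, not part of the diff): the new _validate_pixel_values accepts either one 5-D tensor (batch, patches, 3, H, W) when every image yields the same number of patches, or a list of 4-D tensors (patches, 3, H, W) when patch counts differ per image. A standalone sketch of that branching, with image_size assumed to be 336:

from typing import List, Union

import torch

IMAGE_SIZE = 336  # assumption; the model reads config.vision_config.image_size


def validate_pixel_values(
    data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:

    def _validate_shape(t: torch.Tensor) -> None:
        if t.dim() == 5:
            # Batched case: every image contributed the same patch count.
            trailing = list(t.shape)[2:]
        elif t.dim() == 4:
            # Ragged case: each image is validated on its own.
            trailing = list(t.shape)[1:]
        else:
            raise ValueError(f"Invalid pixel value tensor of shape {t.shape}")
        if trailing != [3, IMAGE_SIZE, IMAGE_SIZE]:
            raise ValueError(
                f"Expected trailing dims (3, {IMAGE_SIZE}, {IMAGE_SIZE}), "
                f"got {t.shape}")

    if isinstance(data, torch.Tensor):
        _validate_shape(data)
    else:
        for t in data:
            _validate_shape(t)
    return data


validate_pixel_values(torch.zeros(2, 5, 3, 336, 336))   # same patch count per image
validate_pixel_values([torch.zeros(5, 3, 336, 336),
                       torch.zeros(9, 3, 336, 336)])    # differing patch counts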


@@ -15,7 +15,7 @@
 # limitations under the License.
 import re
 from functools import lru_cache
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 import numpy as np
 import torch
@@ -24,7 +24,7 @@ from PIL import Image
 from transformers import CLIPVisionConfig, PretrainedConfig
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -50,6 +50,9 @@ _KEYS_TO_MODIFY_MAPPING = {
     "model.vision_embed_tokens": "vision_embed_tokens",
 }
+# Cannot find the following 2 numbers from hf config.
+_IMAGE_TOKEN_ID = 32044
 CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                      hidden_act="quick_gelu",
                                                      hidden_size=1024,
@@ -95,13 +98,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
 class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
     """Phi3 Image embedding with HD transform."""
-    def __init__(self,
-                 vision_language_config: VisionLanguageConfig,
-                 config: PretrainedConfig,
-                 wte=None) -> None:
+    def __init__(self, config: PretrainedConfig, wte=None) -> None:
         super().__init__(wte)
-        self.image_token_id = vision_language_config.image_token_id
+        self.image_token_id = _IMAGE_TOKEN_ID
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
@@ -333,7 +333,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
     seq_data = dummy_seq_data_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
         seq_len,
-        image_token_id=32044,
+        image_token_id=_IMAGE_TOKEN_ID,
         image_feature_size_override=image_feature_size,
     )
     mm_data = dummy_image_for_clip(
@@ -370,7 +370,6 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
     model_config = ctx.model_config
-    multimodal_config = ctx.get_multimodal_config()
     hf_config = ctx.get_hf_config(PretrainedConfig)
     image_data = multi_modal_data["image"]
@@ -407,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
     new_token_ids: List[int] = []
     for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
         if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
-            new_token_ids.append(multimodal_config.image_token_id)
+            new_token_ids.append(_IMAGE_TOKEN_ID)
             # No need to further scan the list since we only replace once
             new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
@@ -424,7 +423,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         model_config,
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
         llm_inputs,
-        image_token_id=multimodal_config.image_token_id,
+        image_token_id=_IMAGE_TOKEN_ID,
         image_feature_size_override=image_feature_size,
     )
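Aside (illustration, not part of the diff): because the HF config for Phi-3-vision does not expose the image token id, the input processor above substitutes the hard-coded _IMAGE_TOKEN_ID (32044) for the first occurrence of the tokenized image placeholder. A standalone sketch of that single-replacement scan; the surrounding token ids are placeholders:

from typing import List

_IMAGE_TOKEN_ID = 32044  # hard-coded in the model file; not found in the HF config


def replace_first_placeholder(prompt_token_ids: List[int],
                              image_1_token_ids: List[int]) -> List[int]:
    new_token_ids: List[int] = []
    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
            new_token_ids.append(_IMAGE_TOKEN_ID)
            # Only the first match is replaced; the remainder is copied verbatim.
            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
            break
        new_token_ids.append(prompt_token_ids[i])
    return new_token_ids


# [7, 8, 9] stands in for the tokenized "<|image_1|>" placeholder.
print(replace_first_placeholder([1, 2, 7, 8, 9, 3], [7, 8, 9]))  # [1, 2, 32044, 3]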
@@ -436,25 +435,53 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
     def __init__(self,
                  config: PretrainedConfig,
-                 vlm_config: VisionLanguageConfig,
+                 multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
-        self.vlm_config = vlm_config
+        self.multimodal_config = multimodal_config
         self.model = LlamaModel(config, cache_config, quant_config)
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_embed_tokens = Phi3HDImageEmbedding(
-            vlm_config, config, self.model.embed_tokens)
+            config, self.model.embed_tokens)
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
                                       quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
+        if list(data.shape[1:]) != [2]:
+            raise ValueError(
+                f"The expected image sizes shape is batch dimension plus "
+                f"{[2]}. You supplied {data.shape}.")
+        return data
+    def _validate_pixel_values(
+        self, data: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        def _validate_shape(data: torch.Tensor):
+            if list(data.shape)[2:] != [
+                    3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
+                    CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
+            ]:
+                raise ValueError(
+                    "The expected pixel value tensor shape is batch dimension "
+                    "plus patch number, channel, height and width.")
+        if isinstance(data, torch.Tensor):
+            _validate_shape(data)
+        else:
+            [_validate_shape(d) for d in data]
+        return data
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -471,9 +498,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
             raise ValueError("Incorrect type of image sizes. "
                              f"Got type: {type(image_sizes)}")
-        return Phi3VImagePixelInputs(type="pixel_values",
-                                     data=pixel_values,
-                                     image_sizes=image_sizes)
+        return Phi3VImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_pixel_values(pixel_values),
+            image_sizes=self._validate_image_sizes(image_sizes))
     def forward(self,
                 input_ids: torch.Tensor,