[vlm] Remove vision language config. (#6089)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -24,7 +24,7 @@ from PIL import Image
|
||||
from transformers import CLIPVisionConfig, PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@@ -50,6 +50,9 @@ _KEYS_TO_MODIFY_MAPPING = {
|
||||
"model.vision_embed_tokens": "vision_embed_tokens",
|
||||
}
|
||||
|
||||
# Cannot find the following 2 numbers from hf config.
|
||||
_IMAGE_TOKEN_ID = 32044
|
||||
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
|
||||
hidden_act="quick_gelu",
|
||||
hidden_size=1024,
|
||||
@@ -95,13 +98,10 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
||||
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
"""Phi3 Image embedding with HD transform."""
|
||||
|
||||
def __init__(self,
|
||||
vision_language_config: VisionLanguageConfig,
|
||||
config: PretrainedConfig,
|
||||
wte=None) -> None:
|
||||
def __init__(self, config: PretrainedConfig, wte=None) -> None:
|
||||
super().__init__(wte)
|
||||
|
||||
self.image_token_id = vision_language_config.image_token_id
|
||||
self.image_token_id = _IMAGE_TOKEN_ID
|
||||
# n_embed or hidden_size
|
||||
hidden_size = config.n_embd if hasattr(
|
||||
config, 'n_embd') else config.hidden_size
|
||||
@@ -333,7 +333,7 @@ def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
seq_len,
|
||||
image_token_id=32044,
|
||||
image_token_id=_IMAGE_TOKEN_ID,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
mm_data = dummy_image_for_clip(
|
||||
@@ -370,7 +370,6 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
return llm_inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
multimodal_config = ctx.get_multimodal_config()
|
||||
hf_config = ctx.get_hf_config(PretrainedConfig)
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
@@ -407,7 +406,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
new_token_ids: List[int] = []
|
||||
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
|
||||
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
|
||||
new_token_ids.append(multimodal_config.image_token_id)
|
||||
new_token_ids.append(_IMAGE_TOKEN_ID)
|
||||
|
||||
# No need to further scan the list since we only replace once
|
||||
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
|
||||
@@ -424,7 +423,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
model_config,
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
|
||||
llm_inputs,
|
||||
image_token_id=multimodal_config.image_token_id,
|
||||
image_token_id=_IMAGE_TOKEN_ID,
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
@@ -436,25 +435,53 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
vlm_config: VisionLanguageConfig,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.vlm_config = vlm_config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.model = LlamaModel(config, cache_config, quant_config)
|
||||
|
||||
# TODO: Optionally initializes this for supporting embeddings.
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(
|
||||
vlm_config, config, self.model.embed_tokens)
|
||||
config, self.model.embed_tokens)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if list(data.shape[1:]) != [2]:
|
||||
raise ValueError(
|
||||
f"The expected image sizes shape is batch dimension plus "
|
||||
f"{[2]}. You supplied {data.shape}.")
|
||||
|
||||
return data
|
||||
|
||||
def _validate_pixel_values(
|
||||
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
def _validate_shape(data: torch.Tensor):
|
||||
if list(data.shape)[2:] != [
|
||||
3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
|
||||
]:
|
||||
raise ValueError(
|
||||
"The expected pixel value tensor shape is batch dimension "
|
||||
"plus patch number, channel, height and width.")
|
||||
|
||||
if isinstance(data, torch.Tensor):
|
||||
_validate_shape(data)
|
||||
else:
|
||||
[_validate_shape(d) for d in data]
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
@@ -471,9 +498,10 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
raise ValueError("Incorrect type of image sizes. "
|
||||
f"Got type: {type(image_sizes)}")
|
||||
|
||||
return Phi3VImagePixelInputs(type="pixel_values",
|
||||
data=pixel_values,
|
||||
image_sizes=image_sizes)
|
||||
return Phi3VImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
image_sizes=self._validate_image_sizes(image_sizes))
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user