[vlm] Remove vision language config. (#6089)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: xwjiang2010
Date: 2024-07-03 15:14:16 -07:00 (committed by GitHub)
Parent: 3c6325f0fc
Commit: d9e98f42e4
43 changed files with 371 additions and 465 deletions
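
In short: the standalone VisionLanguageConfig is removed from vllm.config. Multimodal models now receive a MultiModalConfig, and model-specific details such as the image placeholder token id are read from the HuggingFace config (e.g. LlavaNextConfig.image_token_index) instead of a vLLM-side config. The hunks below show the change for the LLaVA-NeXT model; the other touched files presumably follow the same pattern.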

File: vllm/model_executor/models/llava_next.py

@@ -1,4 +1,4 @@
-from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -9,7 +9,7 @@ from transformers.models.llava_next.modeling_llava_next import (
 from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -204,13 +204,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
                  config: LlavaNextConfig,
-                 vlm_config: VisionLanguageConfig,
+                 multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
 
         self.config = config
-        self.vlm_config = vlm_config
+        self.multimodal_config = multimodal_config
 
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = CLIPVisionModel(config=config.vision_config)
@@ -244,6 +244,47 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         return data
 
+    def _validate_pixel_values(
+        self, data: Union[torch.Tensor, List[torch.Tensor]]
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+
+        def _validate_shape(data: torch.Tensor):
+            dim = data.dim()
+            height = width = self.config.vision_config.image_size
+
+            # All 4d image tensors have the same number of patches,
+            # so data is a 5d batch of these tensors.
+            if dim == 5:
+                if list(data.shape)[2:] != [
+                        3, self.config.vision_config.image_size,
+                        self.config.vision_config.image_size
+                ]:
+                    raise ValueError(
+                        "Expected pixel value tensor in shape of: (batch size, "
+                        f"patch number, 3, {height}, {width}), got {data.shape}"
+                    )
+            # 4d image tensors have different numbers of patches,
+            # so data is each individual tensor.
+            elif dim == 4:
+                if list(data.shape)[1:] != [
+                        3, self.config.vision_config.image_size,
+                        self.config.vision_config.image_size
+                ]:
+                    raise ValueError(
+                        "Expected pixel value tensor in shape of: (patch "
+                        f"number, 3, {height}, {width}), got {data.shape}")
+            else:
+                raise ValueError(
+                    f"Invalid pixel value tensor of shape {data.shape}")
+
+        if isinstance(data, torch.Tensor):
+            _validate_shape(data)
+        else:
+            [_validate_shape(d) for d in data]
+
+        return data
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
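
As a sanity check on the shape logic above, here is a minimal standalone sketch (not from the PR; image_size = 336 is assumed here, matching the CLIP vision tower commonly used by LLaVA-NeXT). A 5d tensor batches images that share one patch count, while a list of 4d tensors carries a per-image patch count:

    import torch

    image_size = 336  # assumed stand-in for vision_config.image_size

    # 5d: a batch of images that all produced the same number of patches,
    # checked against dims 2..4 (3, height, width).
    batched = torch.zeros(2, 5, 3, image_size, image_size)
    assert list(batched.shape)[2:] == [3, image_size, image_size]

    # List of 4d tensors: patch counts differ per image, so each tensor is
    # checked alone against dims 1..3.
    ragged = [torch.zeros(5, 3, image_size, image_size),
              torch.zeros(7, 3, image_size, image_size)]
    assert all(list(t.shape)[1:] == [3, image_size, image_size] for t in ragged)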
@@ -262,7 +303,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         return LlavaNextImagePixelInputs(
             type="pixel_values",
-            data=pixel_values,
+            data=self._validate_pixel_values(pixel_values),
             image_sizes=self._validate_image_sizes(image_sizes),
         )
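
With the validator wired in here, a malformed pixel_values tensor now fails fast with a descriptive ValueError at input-parsing time, rather than surfacing later as an opaque shape mismatch inside the vision tower.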
@@ -454,7 +495,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vlm_config.image_token_id)
+                self.config.image_token_index)
 
             input_ids = None
         else:
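
The image placeholder token id is now read from the HuggingFace LlavaNextConfig rather than the removed VisionLanguageConfig. A hedged sketch of where the value lives (the default is taken from transformers and assumed here):

    from transformers import LlavaNextConfig

    hf_config = LlavaNextConfig()
    # Id of the <image> placeholder token; 32000 by default in transformers.
    print(hf_config.image_token_index)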