[vlm] Remove vision language config. (#6089)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
xwjiang2010
2024-07-03 15:14:16 -07:00
committed by GitHub
parent 3c6325f0fc
commit d9e98f42e4
43 changed files with 371 additions and 465 deletions

View File

@@ -1,8 +1,7 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Tuple,
Union)
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
import torch
from transformers import PretrainedConfig
@@ -120,7 +119,7 @@ class ModelConfig:
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["VisionLanguageConfig"] = None,
multimodal_config: Optional["MultiModalConfig"] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -1289,35 +1288,12 @@ class LoRAConfig:
raise ValueError("LoRA is not supported with chunked prefill yet.")
# TODO: To be replaced by MultiModalConfig.
@dataclass
class VisionLanguageConfig:
class MultiModalConfig:
"""Configs the input data format and how models should run for
vision language models."""
# The input id corresponding to image token.
image_token_id: int
# Used for running `run_prefill_max_token`.
# For models that support varying resolution, this corresponds to
# worst case scenario (biggest supported resolution).
image_input_shape: tuple
image_feature_size: int
def as_cli_args_dict(self) -> Dict[str, Any]:
"""Flatten vision language config to pure args.
Compatible with what llm entrypoint expects.
"""
result: Dict[str, Any] = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, enum.Enum):
result[f.name] = value.name.lower()
elif isinstance(value, tuple):
result[f.name] = ",".join([str(item) for item in value])
else:
result[f.name] = value
return result
multimodal models."""
# TODO: Add configs to init vision tower or not.
pass
_STR_DTYPE_TO_TORCH_DTYPE = {
@@ -1541,7 +1517,7 @@ class EngineConfig:
device_config: DeviceConfig
load_config: LoadConfig
lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig]
multimodal_config: Optional[MultiModalConfig]
speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]