[Model] Refactor Qwen2-VL to use merged multimodal processor (#11258)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Isotr0py
2024-12-20 00:28:00 +08:00
committed by GitHub
parent 7379b3d4b2
commit e24113a8fe
5 changed files with 272 additions and 522 deletions

View File

@@ -164,7 +164,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
audio_len = get_max_qwen2_audio_audio_tokens(self.ctx)
feature_extractor = self._get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
audio_count = mm_counts["audio"]
audio = np.zeros(audio_len)

View File

@@ -22,28 +22,26 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import cached_property, partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, Type, TypedDict, Union)
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, Type, TypedDict, Union)
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from PIL import Image
from transformers.image_utils import (get_image_size,
infer_channel_dimension_format,
to_numpy_array)
from transformers import BatchFeature
from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
Qwen2VLProcessor)
from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig, Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
make_batched_images, make_batched_videos, smart_resize)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU
@@ -56,14 +54,14 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
@@ -159,7 +157,7 @@ class Qwen2VisionMLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int = None,
hidden_features: int,
act_layer: Type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -644,78 +642,8 @@ class Qwen2VisionTransformer(nn.Module):
# === Vision input helpers === #
def get_mm_processor_kwargs(
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None) -> Dict[str, int]:
mm_processor_kwargs = {}
if min_pixels:
mm_processor_kwargs["min_pixels"] = min_pixels
if max_pixels:
mm_processor_kwargs["max_pixels"] = max_pixels
return mm_processor_kwargs
def mm_input_mapper_for_qwen2_vl(
ctx: InputContext,
data: MultiModalData[object],
data_type_key: str,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
) -> MultiModalKwargs:
"""Input mapper for Qwen2-VL."""
if data_type_key == "image" and isinstance(data, dict):
return MultiModalKwargs({
"image_embeds": data.get("image_embeds"),
"image_grid_thw": data.get("image_grid_thw"),
})
if data_type_key == "video" and isinstance(data, dict):
return MultiModalKwargs({
"video_embeds": data.get("video_embeds"),
"video_grid_thw": data.get("video_grid_thw"),
})
model_config = ctx.model_config
# Handle mm processor kwargs; we pass these at creation time
# because preprocess() in transformers doesn't expose them
mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
max_pixels=max_pixels)
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available "
"to process the image object")
images = None
videos = None
if data_type_key == "image":
images = data
else:
assert data_type_key == "video"
videos = data
try:
batch_data = image_processor \
.preprocess(images=images, videos=videos, return_tensors="pt") \
.data
except Exception:
logger.error("Failed to process image (%s)", data)
raise
return MultiModalKwargs(batch_data)
image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
data_type_key="image")
video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
data_type_key="video")
def _get_vision_info(
image_processor,
vision_config: Qwen2VLVisionConfig,
height: int,
width: int,
min_pixels: int,
@@ -726,12 +654,15 @@ def _get_vision_info(
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
resized_height, resized_width = smart_resize(
height=height,
width=width,
factor=image_processor.patch_size * image_processor.merge_size,
factor=patch_size * merge_size,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
@@ -742,54 +673,41 @@ def _get_vision_info(
grid_t = mm_count
else:
assert data_type_key == "video"
grid_t = max(mm_count // image_processor.temporal_patch_size, 1)
grid_t = max(mm_count // temporal_patch_size, 1)
grid_h = resized_height // image_processor.patch_size
grid_w = resized_width // image_processor.patch_size
grid_h = resized_height // patch_size
grid_w = resized_width // patch_size
vision_tokens = grid_t * grid_h * grid_w
llm_num_vision_tokens = (vision_tokens // image_processor.merge_size //
image_processor.merge_size)
llm_num_vision_tokens = vision_tokens // (merge_size**2)
return resized_height, resized_width, llm_num_vision_tokens
def _get_max_image_info(
image_processor,
data_type_key: str = "image",
mm_count: int = 1,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
):
# Limit min / max pixels unless they're explicitly provided
if min_pixels is None:
min_pixels = max(image_processor.min_pixels, 28 * 28)
if max_pixels is None:
max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28)
return _get_vision_info(
image_processor,
height=9999999,
width=9999999,
min_pixels=min_pixels,
max_pixels=max_pixels,
data_type_key=data_type_key,
mm_count=mm_count,
)
def _get_image_processor(hf_processor: Qwen2VLProcessor):
image_processor = hf_processor.image_processor # type: ignore
assert isinstance(image_processor, Qwen2VLImageProcessor)
return image_processor
def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
data_type_key: str,
*,
min_pixels=None,
max_pixels=None) -> int:
mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
max_pixels=max_pixels)
image_processor = cached_get_image_processor(ctx.model_config.model,
**mm_processor_kwargs)
max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key=data_type_key,
mm_count=1, min_pixels=min_pixels,
max_pixels=max_pixels)
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None) -> int:
hf_config = ctx.get_hf_config(Qwen2VLConfig)
vision_config = hf_config.vision_config
hf_processor = ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor)
_, _, max_llm_image_tokens = _get_vision_info(
vision_config,
height=9999999,
width=9999999,
min_pixels=min_pixels or image_processor.min_pixels,
max_pixels=max_pixels or image_processor.max_pixels,
data_type_key=data_type_key,
)
return max_llm_image_tokens
@@ -799,290 +717,166 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="video")
def dummy_data_for_qwen2_vl(
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
mm_processor_kwargs = get_mm_processor_kwargs(min_pixels=min_pixels,
max_pixels=max_pixels)
image_processor = cached_get_image_processor(ctx.model_config.model,
**mm_processor_kwargs)
class Qwen2VLMultiModalDataItems(MultiModalDataItems):
num_images = mm_counts["image"]
max_resized_height, max_resized_width, max_llm_image_tokens = \
_get_max_image_info(image_processor, data_type_key="image",
mm_count=num_images, min_pixels=min_pixels,
max_pixels=max_pixels)
if seq_len - max_llm_image_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-VL cannot process {num_images} images in a prompt, "
"please increase max_model_len or reduce image limit by "
"--limit-mm-per-prompt.")
@staticmethod
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = Qwen2VLMultiModalDataItems()
# Check video counts.
num_videos = mm_counts["video"]
max_resized_height, max_resized_width, max_llm_video_tokens = \
_get_max_image_info(image_processor, data_type_key="video",
mm_count=num_videos, min_pixels=min_pixels,
max_pixels=max_pixels)
if seq_len - max_llm_video_tokens - 2 < 0:
raise RuntimeError(
f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
"please increase max_model_len or reduce video limit by "
"--limit-mm-per-prompt.")
for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (dict, torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
hf_config = ctx.get_hf_config(Qwen2VLConfig)
return multi_data
dummy_seqdata = SequenceData.from_prompt_token_counts(
(hf_config.vision_start_token_id, 1),
(hf_config.image_token_id, max_llm_image_tokens),
(hf_config.vision_end_token_id, 1),
(0, seq_len - max_llm_image_tokens - 2),
)
dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
color=0)
return DummyData(dummy_seqdata, {
"image":
dummy_image if num_images == 1 else [dummy_image] * num_images
})
def get_item_counts(self) -> Mapping[str, int]:
return {
m: (
len(items[f"{m}_grid_thw"]) # type: ignore
if isinstance(items, dict) else len(items))
for m, items in self.items()
}
def _get_llm_num_vision_tokens(
mm_inputs: list,
data_type_key: str,
image_processor,
min_pixels: int,
max_pixels: int,
):
"""Get number of vision tokens of multimodal inputs.
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
This method is derived from `transformers.models.qwen2_vl.
image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
"""
image = to_numpy_array(mm_inputs[0])
input_data_format = infer_channel_dimension_format(image)
height, width = get_image_size(image, channel_dim=input_data_format)
def _get_mm_items(
self,
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
return Qwen2VLMultiModalDataItems.from_dict(mm_data)
_, _, llm_num_vision_tokens = _get_vision_info(
image_processor,
height=height,
width=width,
min_pixels=min_pixels,
max_pixels=max_pixels,
do_resize=image_processor.do_resize,
data_type_key=data_type_key,
mm_count=len(mm_inputs),
)
return llm_num_vision_tokens
def _get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
) -> Qwen2VLProcessor:
hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor)
image_processor = _get_image_processor(hf_processor)
if min_pixels:
image_processor.min_pixels = min_pixels
if max_pixels:
image_processor.max_pixels = max_pixels
if max_pixels or min_pixels:
image_processor.size = {
"min_pixels": image_processor.min_pixels,
"max_pixels": image_processor.max_pixels,
}
def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
data_type_key: str, image_processor: Any,
prompt_token_ids: List[int], min_pixels: Optional[int],
max_pixels: Optional[int]) -> List[int]:
"""
Expand pad tokens for multi-modal inputs (e.g., images or videos).
return hf_processor
Args:
inputs (list): The multi-modal inputs (e.g., images or videos).
token_id (int): The token ID used to represent the multi-modal input.
make_batched_fn (Callable): A function to batch the inputs.
data_type_key (str): The type of the multi-modal input.
image_processor (Any): The image processor used to process the inputs.
prompt_token_ids (List[int]): The list of token IDs in the prompt.
min_pixels (int): min pixels to used for img processing
max_pixels (int): max pixels to be used for img processing
def _get_processor_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
Returns:
List[int]: The list of token IDs for the multi-modal inputs.
"""
indices = [
idx for idx, token in enumerate(prompt_token_ids) if token == token_id
]
inputs = make_batched_fn(inputs)
assert len(indices) == len(inputs)
for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, dict):
# Pass through embedding inputs (dict)
passthrough_data.update(v)
elif isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
else:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
prompt_token_ids_with_data = []
for cnt, data in enumerate(inputs):
num_tokens = _get_llm_num_vision_tokens(
[data] if data_type_key == "image" else data,
data_type_key=data_type_key,
image_processor=image_processor,
min_pixels=min_pixels,
max_pixels=max_pixels,
return processor_data, passthrough_data
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
# NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has
# image_token and video_token registered
placeholder = {
"image": hf_processor.image_token,
"video": hf_processor.video_token,
}
merge_length = image_processor.merge_size**2
def get_replacement_qwen2vl(item_idx: int, modality: str):
grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
num_tokens = grid_thw.prod() // merge_length
return placeholder[modality] * num_tokens
return [
PromptReplacement(
modality=modality,
target=placeholder[modality],
replacement=partial(get_replacement_qwen2vl,
modality=modality),
) for modality in ("image", "video")
]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts["image"]
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
image_processor = _get_image_processor(hf_processor)
data = {}
resized_height, resized_width = smart_resize(
height=9999999,
width=9999999,
factor=image_processor.patch_size * image_processor.merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
)
dummy_image = Image.new("RGB", (resized_width, resized_height),
color=0)
data["image"] = [dummy_image] * num_images
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)
if cnt == 0:
end_idx = indices[cnt]
non_data_tokens = prompt_token_ids[:end_idx]
else:
non_data_tokens = prompt_token_ids[indices[cnt - 1] +
1:indices[cnt]]
prompt_token_ids_with_data.extend(non_data_tokens)
prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
return prompt_token_ids_with_data
def input_processor_for_qwen2_vl(
ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None:
return inputs
image_inputs = multi_modal_data.get("image", None)
video_inputs = multi_modal_data.get("video", None)
processor = cached_get_processor(ctx.model_config.model)
image_processor = processor.image_processor
# Apply processor kwarg overrides for image processor options
min_pixels = min_pixels if min_pixels else image_processor.min_pixels
max_pixels = max_pixels if max_pixels else image_processor.max_pixels
model_config = ctx.model_config
hf_config = ctx.get_hf_config(Qwen2VLConfig)
# To avoid redundant processing of vision objects (resize, rescale, etc.),
# we extract code of calculating number of vision tokens from
# `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
#
# The following code is equivalent to:
# prompt = inputs["prompt"]
# inputs = processor(text=[prompt],
# images=image_inputs,
# videos=video_inputs,
# padding=True,
# return_tensors="pt")
# prompt_token_ids = inputs["input_ids"][0].tolist()
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt_token_ids = inputs["prompt_token_ids"]
# Expand image pad tokens.
if image_inputs is not None:
if isinstance(image_inputs, dict):
prompt_token_ids_with_image = []
image_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.image_token_id
]
# ensure all image tokens have grid_thw
assert \
len(image_indices) == image_inputs["image_grid_thw"].size(0), \
"image token num does not match image_grid_thw.shape"
image_counter = 0
pad_token_counter = 0
for idx, token in enumerate(prompt_token_ids):
if idx in image_indices:
grid_thw = image_inputs["image_grid_thw"][image_counter]
grid_t, grid_h, grid_w = grid_thw
num_pad_tokens = (grid_t * grid_h * grid_w //
image_processor.merge_size //
image_processor.merge_size)
prompt_token_ids_with_image.extend([token] *
num_pad_tokens)
image_counter += 1
pad_token_counter += num_pad_tokens
else:
prompt_token_ids_with_image.append(token)
# ensure all embeddings are used
assert \
pad_token_counter == image_inputs["image_embeds"].size(0), \
"image_embeds.shape does not match image_grid_thw"
prompt_token_ids = prompt_token_ids_with_image
else:
prompt_token_ids = _expand_pad_tokens(image_inputs,
hf_config.image_token_id,
make_batched_images,
"image",
image_processor,
prompt_token_ids,
min_pixels=min_pixels,
max_pixels=max_pixels)
if video_inputs is not None:
if isinstance(video_inputs, dict):
prompt_token_ids_with_video = []
video_indices = [
idx for idx, token in enumerate(prompt_token_ids)
if token == hf_config.video_token_id
]
# ensure all video tokens have grid_thw
assert \
len(video_indices) == video_inputs["video_grid_thw"].size(0), \
"video token num does not match video_grid_thw.shape"
video_counter = 0
pad_token_counter = 0
for idx, token in enumerate(prompt_token_ids):
if idx in video_indices:
grid_thw = video_inputs["video_grid_thw"][video_counter]
grid_t, grid_h, grid_w = grid_thw
num_pad_tokens = (grid_t * grid_h * grid_w //
image_processor.merge_size //
image_processor.merge_size)
prompt_token_ids_with_video.extend([token] *
num_pad_tokens)
video_counter += 1
pad_token_counter += num_pad_tokens
else:
prompt_token_ids_with_video.append(token)
# ensure all embeddings are used
assert \
pad_token_counter == video_inputs["video_embeds"].size(0), \
"video_embeds.shape does not match video_grid_thw"
prompt_token_ids = prompt_token_ids_with_video
else:
prompt_token_ids = _expand_pad_tokens(video_inputs,
hf_config.video_token_id,
make_batched_videos,
"video",
image_processor,
prompt_token_ids,
min_pixels=min_pixels,
max_pixels=max_pixels)
prompt = inputs.get("prompt")
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
return token_inputs(
prompt_token_ids=prompt_token_ids,
prompt=prompt,
multi_modal_data=multi_modal_data,
)
@MULTIMODAL_REGISTRY.register_image_input_mapper(
image_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_input_mapper("video",
video_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
packed_modules_mapping = {
@@ -1110,7 +904,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config: Qwen2VLConfig = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config