[VLM] Support caching in merged multi-modal processor (#11396)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2024-12-28 01:22:48 +08:00; committed by GitHub
parent 5ce4627a7e
commit 101418096f
20 changed files with 1459 additions and 452 deletions
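The hunks below migrate each model's multi-modal processor onto the reworked `BaseMultiModalProcessor` interface (`_call_hf_processor`, `_get_mm_fields_config`, `_get_prompt_replacements`), whose per-field configs let HF processor outputs be split per item and cached. The cache itself lives in `vllm/multimodal/processing.py`, which is not part of the hunks shown here; the sketch below is only a rough illustration of the general idea (memoizing a processor call on a key derived from the prompt, the item, and the kwargs), with all names hypothetical.

```python
# Illustrative sketch only; not the code added by this commit.
import hashlib
import pickle
from collections import OrderedDict
from typing import Any, Callable, Mapping


class ProcessorCallCache:
    """Memoize per-item HF processor calls on a content-derived key."""

    def __init__(self, call: Callable[..., Any], capacity: int = 256) -> None:
        self._call = call
        self._capacity = capacity
        self._entries: "OrderedDict[str, Any]" = OrderedDict()

    def _key(self, prompt: str, item: Any, kwargs: Mapping[str, Any]) -> str:
        # Real code would hash tensors/arrays explicitly; pickle is a stand-in.
        blob = pickle.dumps((prompt, item, sorted(kwargs.items())))
        return hashlib.sha256(blob).hexdigest()

    def get(self, prompt: str, item: Any, **kwargs: Any) -> Any:
        key = self._key(prompt, item, kwargs)
        if key in self._entries:
            self._entries.move_to_end(key)     # LRU bookkeeping
            return self._entries[key]
        out = self._call(prompt, item, **kwargs)
        self._entries[key] = out
        if len(self._entries) > self._capacity:
            self._entries.popitem(last=False)  # evict least recently used
        return out
```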


@@ -1,5 +1,4 @@
from functools import cached_property
-from types import MethodType
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
Tuple, TypedDict, Union)
@@ -7,7 +6,7 @@ import torch
import torch.nn as nn
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
-ProcessorMixin, SiglipVisionConfig)
SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
@@ -21,10 +20,12 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-MultiModalDataItems, ProcessorInputs,
-PromptReplacement)
ProcessorInputs, PromptReplacement,
full_groupby_modality)
from vllm.sequence import IntermediateTensors
from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -116,36 +117,54 @@ def get_max_llava_image_tokens(ctx: InputContext):
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
-def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
-if getattr(hf_processor, "__is_patched__", False):
-return # Already patched
-image_processor = hf_processor.image_processor # type: ignore
-orig_preprocess = image_processor.preprocess
-def preprocess(__self, *args, **kwargs):
-hf_inputs = orig_preprocess(*args, **kwargs)
-hf_inputs["is_pixtral"] = torch.tensor(True)
-return hf_inputs
-image_processor.preprocess = MethodType(preprocess, image_processor)
-hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
-hf_processor = self.ctx.get_hf_processor(
-(LlavaProcessor, PixtralProcessor))
return self.ctx.get_hf_processor((LlavaProcessor, PixtralProcessor))
-if isinstance(hf_processor, PixtralProcessor):
-self._patch_pixtral_processor(hf_processor)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
-return hf_processor
# NOTE: pixel_values=None for MLlavaProcessor
pixel_values = processed_outputs.get("pixel_values")
if pixel_values is not None:
images = mm_data["images"]
assert isinstance(images, list)
if isinstance(self._get_hf_processor(), PixtralProcessor):
# Original output: (1, num_images, C, H, W)
# New output: (num_images, C, H, W)
assert (isinstance(pixel_values, list)
and len(pixel_values) == 1
and isinstance(pixel_values[0], list)
and len(pixel_values[0]) == len(images))
processed_outputs["pixel_values"] = pixel_values[0]
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
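`MultiModalFieldConfig.batched("image")` declares that the first dimension of `pixel_values`/`image_embeds` indexes images, so the framework can split the processor output back into per-image entries (which is what makes per-item reuse possible). A minimal sketch of that splitting, under the assumption that a "batched" field simply slices along dim 0 (illustrative, not the vLLM internals):

```python
import torch

def split_batched_field(batched: torch.Tensor) -> list[torch.Tensor]:
    # A "batched" field stacks one entry per multi-modal item along dim 0,
    # so item i is simply batched[i].
    return [batched[i] for i in range(batched.shape[0])]

pixel_values = torch.randn(2, 3, 336, 336)   # two images from the HF processor
per_image = split_batched_field(pixel_values)
assert len(per_image) == 2 and per_image[0].shape == (3, 336, 336)
```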
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
-hf_inputs: BatchFeature,
-mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
@@ -200,7 +219,7 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
-num_images = mm_counts["image"]
num_images = mm_counts.get("image", 0)
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
@@ -218,7 +237,6 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
-mm_processor_kwargs={},
)
@@ -379,7 +397,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
-is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
@@ -390,33 +407,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
-assert isinstance(is_pixtral, torch.Tensor)
-if is_pixtral.any():
-images = pixel_values
-def flatten_to_3d_tensors(item):
-if isinstance(item, torch.Tensor):
-if item.dim() >= 3:
-return [t for t in item.view(-1, *item.shape[-3:])]
-else:
-raise ValueError(
-f"Unexpected tensor dimension: {item.dim()}")
-elif isinstance(item, list):
-return [
-t for subitem in item
-for t in flatten_to_3d_tensors(subitem)
-]
-else:
-raise ValueError(f"Unexpected type: {type(item)}")
-# Restructure the batched images into a list of lists of images
-images = flatten_to_3d_tensors(pixel_values)
-return LlavaImagePixelInputs(
-type="pixel_values",
-data=images,
-)
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
@@ -586,19 +576,71 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> ProcessorMixin:
try:
from mantis.models.mllava import MLlavaProcessor
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"You need to `pip install "
"git+https://github.com/TIGER-AI-Lab/Mantis.git` "
"to use this model") from exc
-def _get_hf_processor(self):
-return self.ctx.get_hf_processor(LlavaProcessor)
processor = MLlavaProcessor.from_pretrained(
self.ctx.model_config.tokenizer)
assert isinstance(processor, ProcessorMixin)
return processor
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
max_image_tokens = get_max_llava_image_tokens(self.ctx)
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
mm_items = self._get_mm_items(mm_data)
mm_item_counts = mm_items.get_item_counts()
mm_kwargs = result["mm_kwargs"]
# We reimplement the functionality of MLlavaProcessor from
# https://github.com/TIGER-AI-Lab/Mantis.git
def get_replacement_mantis(item_idx: int):
return "".join([
f"(image {item_idx+1}: <Image>", # 7 tokens
"<image>" * max_image_tokens,
"</Image>)", # 3 tokens
])
mantis_repls = self._bind_prompt_replacements([
PromptReplacement(
modality="image",
target=[image_token_id] * max_image_tokens,
replacement=get_replacement_mantis,
)
])
prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
result["prompt_token_ids"],
mantis_repls,
mm_item_counts,
)
unbound_orig_repls = self._get_prompt_replacements(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
orig_repls = self._bind_prompt_replacements(unbound_orig_repls)
all_placeholders = self._find_placeholders(orig_repls, prompt_ids,
mm_item_counts)
assert len(all_placeholders) == mm_item_counts.get("image", 0)
mm_placeholders = {
modality: [item.to_range() for item in items]
for modality, items in full_groupby_modality(all_placeholders)
}
return MultiModalInputsV2(
type="multimodal",
prompt=prompt_text,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders,
)
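`get_replacement_mantis` rebuilds the wrapper text that MLlavaProcessor normally places around each image. For illustration only, with a hypothetical `max_image_tokens` of 4 (the real value comes from `get_max_llava_image_tokens`), the replacement for the first image looks like this:

```python
max_image_tokens = 4  # hypothetical value for illustration

def get_replacement_mantis(item_idx: int) -> str:
    return "".join([
        f"(image {item_idx+1}: <Image>",   # prefix around the image tokens
        "<image>" * max_image_tokens,      # one placeholder per image token
        "</Image>)",                       # suffix
    ])

print(get_replacement_mantis(0))
# (image 1: <Image><image><image><image><image></Image>)
```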
# To use this model, please use


@@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@@ -32,10 +32,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalInputsV2,
MultiModalKwargs, NestedTensors,
PlaceholderRange)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-MultiModalDataItems, ProcessorInputs,
-PromptReplacement)
ProcessorInputs, PromptReplacement,
_BoundPromptReplacement,
_PlaceholderInfo)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -306,11 +310,11 @@ def get_max_phi3v_image_tokens(
*,
num_crops: Optional[int] = None,
) -> int:
-mm_processor_kwargs = {}
hf_processor_mm_kwargs = {}
if num_crops:
-mm_processor_kwargs["num_crops"] = num_crops
hf_processor_mm_kwargs["num_crops"] = num_crops
-processor = ctx.get_hf_processor(**mm_processor_kwargs)
processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)
return processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
@@ -331,39 +335,50 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _call_hf_processor(
self,
-hf_processor: ProcessorMixin,
prompt: str,
-processor_data: Mapping[str, object],
-mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
-hf_processor,
prompt=prompt,
-processor_data=processor_data,
-mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
input_ids = processed_outputs["input_ids"]
assert isinstance(input_ids, torch.Tensor)
# Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
# which will cause OverflowError when decoding the prompt_ids.
# Therefore, we need to do an early replacement here
-token_ids = processed_outputs['input_ids']
-token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
-processed_outputs['input_ids'] = token_ids
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
return processed_outputs
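The HF Phi-3-vision processor emits negative ids (-1, -2, ...) as per-image placeholders, which break detokenization, so they are rewritten in place to the single image token id. A small standalone illustration of the same `masked_fill_` pattern (the surrounding token values are made up):

```python
import torch

_IMAGE_TOKEN_ID = 32044  # image token id; the exact value here is illustrative

input_ids = torch.tensor([1, 450, -1, -1, -2, -2, 29889])
input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)
print(input_ids)  # tensor([    1,   450, 32044, 32044, 32044, 32044, 29889])
```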
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
-hf_inputs: BatchFeature,
-mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
-mm_config = self.ctx.get_mm_config()
-max_images = mm_config.limit_per_prompt.get("image", 1)
tokenizer = self._get_tokenizer()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
@@ -372,21 +387,44 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
height=image_size.height,
)
-return [_IMAGE_TOKEN_ID] * num_tokens
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
-) for image_token in image_tokens[:max_images]
) for image_token in image_tokens[:len(mm_items.images)]
]
def _apply_prompt_replacements(
self,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids=token_ids,
prompt_repls=prompt_repls,
mm_item_counts=mm_item_counts,
)
# Keep the behavior in line with HF processor
if text.startswith("<s> <|image|>"):
text = text.replace("<s> <|image|>", "<s><|image|>", 1)
token_ids = [token_ids[0], *token_ids[2:]]
placeholders = [
_PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement)
for p in placeholders
]
return token_ids, text, placeholders
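The override keeps the vLLM-side prompt aligned with HF behavior: the HF processor collapses the space between `<s>` and `<|image|>`, so one token is dropped and every placeholder's start index shifts left by one. A concrete, made-up example of the list surgery:

```python
# Hypothetical token ids: [<s>, lone-space token, <|image|>, image tokens...]
token_ids = [1, 29871, 32010, 32044, 32044]

# Drop the lone space token, mirroring token_ids = [token_ids[0], *token_ids[2:]]
token_ids = [token_ids[0], *token_ids[2:]]
print(token_ids)  # [1, 32010, 32044, 32044]
# Any placeholder that started at index i now starts at i - 1.
```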
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
-num_images = mm_counts["image"]
num_images = mm_counts.get("image", 0)
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@@ -401,9 +439,28 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
-mm_processor_kwargs={},
)
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
# Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)


@@ -225,7 +225,7 @@ class VisualAttentionBlock(nn.Module):
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
-norm_layer: Callable = nn.LayerNorm,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
@@ -266,7 +266,7 @@ class TransformerBlock(nn.Module):
layers: int,
heads: int,
mlp_ratio: float = 4.0,
-norm_layer: Callable = nn.LayerNorm,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()


@@ -26,7 +26,7 @@ from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple,
import numpy as np
import torch
import torch.nn as nn
-from transformers import BatchFeature, ProcessorMixin
from transformers import BatchFeature
from transformers.models.qwen2_audio import (Qwen2AudioConfig,
Qwen2AudioEncoder,
Qwen2AudioProcessor)
@@ -38,10 +38,10 @@ from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-MultiModalDataItems, ProcessorInputs,
-PromptReplacement)
ProcessorInputs, PromptReplacement)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
@@ -73,7 +73,7 @@ class Qwen2AudioMultiModalProjector(nn.Module):
# From Qwen2AudioEncoder._get_feat_extract_output_lengths
-def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
feat_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (feat_lengths - 2) // 2 + 1
return feat_lengths, output_lengths
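The two integer divisions mirror the stride-2 downsampling stages in the audio encoder, each roughly halving the sequence. A worked example for a hypothetical 3000-frame mel input gives 1500 encoder frames and 750 output embeddings:

```python
import torch

def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths

feat, out = _get_feat_extract_output_lengths(torch.tensor([3000]))
print(feat.item(), out.item())  # 1500 750
```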
@@ -88,13 +88,18 @@ def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
-def _get_hf_processor(self) -> Qwen2AudioProcessor:
def _get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor)
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore
-def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -102,50 +107,61 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
-return super()._get_processor_data(mm_items)
return super()._get_hf_mm_data(mm_items)
def _call_hf_processor(
self,
-hf_processor: ProcessorMixin,
prompt: str,
-processor_data: Mapping[str, object],
-mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
-processor_data = dict(processor_data)
-audios = processor_data.pop("audios", [])
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
if audios:
-processor_data["audios"] = audios
mm_data["audios"] = audios
feature_extractor = self._get_feature_extractor()
-mm_processor_kwargs = dict(
-**mm_processor_kwargs,
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
else:
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
pass
-return super()._call_hf_processor(
-hf_processor,
processed_outputs = super()._call_hf_processor(
prompt=prompt,
-processor_data=processor_data,
-mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
input_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
-hf_inputs: BatchFeature,
-mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
placeholder = hf_config.audio_token_index
-feature_attention_mask = hf_inputs.get("feature_attention_mask")
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
if feature_attention_mask is None:
audio_output_lengths = []
else:
assert isinstance(feature_attention_mask, torch.Tensor)
_, audio_output_lengths = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1))
@@ -168,14 +184,13 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
-audio_count = mm_counts["audio"]
audio_count = mm_counts.get("audio", 0)
audio = np.zeros(audio_len)
data = {"audio": [audio] * audio_count}
return ProcessorInputs(
prompt_text="<|AUDIO|>" * audio_count,
mm_data=data,
-mm_processor_kwargs={},
)


@@ -22,9 +22,10 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import cached_property, partial
-from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
-Tuple, Type, TypedDict, Union)
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -54,10 +55,11 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-MultiModalDataItems, ProcessorInputs,
-PromptReplacement)
ProcessorInputs, PromptReplacement)
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
@@ -229,9 +231,9 @@ class Qwen2VisionAttention(nn.Module):
def __init__(
self,
-embed_dim: Optional[int] = None,
-num_heads: Optional[int] = None,
-projection_size: Optional[int] = None,
embed_dim: int,
num_heads: int,
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
@@ -264,7 +266,7 @@ class Qwen2VisionAttention(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
-rotary_pos_emb: torch.Tensor = None,
rotary_pos_emb: torch.Tensor,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)
@@ -347,7 +349,7 @@ class Qwen2VisionBlock(nn.Module):
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
-norm_layer: Type[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
@@ -384,7 +386,7 @@ class Qwen2VisionPatchEmbed(nn.Module):
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
-in_chans: int = 3,
in_channels: int = 3,
embed_dim: int = 1152,
) -> None:
super().__init__()
@@ -392,8 +394,8 @@ class Qwen2VisionPatchEmbed(nn.Module):
self.temporal_patch_size = temporal_patch_size
self.embed_dim = embed_dim
-kernel_size = [temporal_patch_size, patch_size, patch_size]
-self.proj = nn.Conv3d(in_chans,
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d(in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
@@ -413,7 +415,7 @@ class Qwen2VisionPatchMerger(nn.Module):
self,
d_model: int,
context_dim: int,
-norm_layer: Type[nn.Module] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
spatial_merge_size: int = 2,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -489,15 +491,15 @@ class Qwen2VisionTransformer(nn.Module):
) -> None:
super().__init__()
-patch_size: int = vision_config.patch_size
-temporal_patch_size: int = vision_config.temporal_patch_size
-spatial_merge_size: int = vision_config.spatial_merge_size
-in_chans: int = vision_config.in_chans
-hidden_size: int = vision_config.hidden_size
-embed_dim: int = vision_config.embed_dim
-depth: int = vision_config.depth
-num_heads: int = vision_config.num_heads
-mlp_ratio: float = vision_config.mlp_ratio
patch_size = vision_config.patch_size
temporal_patch_size = vision_config.temporal_patch_size
spatial_merge_size = vision_config.spatial_merge_size
in_channels = vision_config.in_channels
hidden_size = vision_config.hidden_size
embed_dim = vision_config.embed_dim
depth = vision_config.depth
num_heads = vision_config.num_heads
mlp_ratio = vision_config.mlp_ratio
self.spatial_merge_size = spatial_merge_size
self.num_heads = num_heads
@@ -506,7 +508,7 @@ class Qwen2VisionTransformer(nn.Module):
self.patch_embed = Qwen2VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
-in_chans=in_chans,
in_channels=in_channels,
embed_dim=embed_dim,
)
@@ -733,8 +735,12 @@ class Qwen2VLMultiModalDataItems(MultiModalDataItems):
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
-v if (isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
-or is_list_of(v, list)) else [v]
v if (
isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
@@ -754,6 +760,12 @@ class Qwen2VLMultiModalDataItems(MultiModalDataItems):
for m, items in self.items()
}
def has_embedding_inputs(self) -> bool:
return any(
isinstance(items, dict) or any(
isinstance(item, torch.Tensor) for item in items)
for items in self.values())
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
@@ -784,7 +796,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return hf_processor
-def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -805,7 +817,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
-else:
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
@@ -816,8 +828,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
-hf_inputs: BatchFeature,
-mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_processor = _get_image_processor(hf_processor)
@@ -831,7 +843,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
merge_length = image_processor.merge_size**2
def get_replacement_qwen2vl(item_idx: int, modality: str):
-grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
assert isinstance(grid_thw, torch.Tensor)
num_tokens = grid_thw.prod() // merge_length
return placeholder[modality] * num_tokens
@@ -844,11 +858,40 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
) for modality in ("image", "video")
]
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()
image_slices = [
slice(image_slice_idxs[i], image_slice_idxs[i + 1])
for i in range(len(image_grid_thw))
]
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist()
video_slices = [
slice(video_slice_idxs[i], video_slice_idxs[i + 1])
for i in range(len(video_grid_thw))
]
return dict(
pixel_values=MultiModalFieldConfig.flat("image", image_slices),
image_embeds=MultiModalFieldConfig.flat("image", image_slices),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat(
"video", video_slices),
video_embeds=MultiModalFieldConfig.flat("video", video_slices),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
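Unlike the "batched" fields above, Qwen2-VL's `pixel_values` concatenates the patches of all images along dim 0, so per-image boundaries have to be recovered from `image_grid_thw`: each image contributes T*H*W patch rows, and cumulative sums of those products give the slice for each image. A worked example with made-up grids:

```python
import torch

# Hypothetical grids for two images: (T, H, W) patch counts.
image_grid_thw = torch.tensor([[1, 4, 6], [1, 8, 8]])

idxs = [0] + image_grid_thw.prod(-1).cumsum(0).tolist()   # [0, 24, 88]
slices = [slice(idxs[i], idxs[i + 1]) for i in range(len(image_grid_thw))]
print(slices)  # [slice(0, 24, None), slice(24, 88, None)]

# In the flat layout, pixel_values for image i would be pixel_values[slices[i]].
```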
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
-num_images = mm_counts["image"]
num_images = mm_counts.get("image", 0)
hf_processor = self._get_hf_processor()
image_token: str = hf_processor.image_token
image_processor = _get_image_processor(hf_processor)
@@ -869,7 +912,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
-mm_processor_kwargs={},
)
@@ -950,9 +992,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
return quant_config
-def _validate_and_reshape_mm_tensor(self,
-mm_input: Union[torch.Tensor,
-List[torch.Tensor]],
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
@@ -962,7 +1002,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim}")
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
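The helper accepts either one batched 3D tensor or a list of 2D tensors and normalizes both to a single concatenated 2D tensor. A quick sketch of the two accepted shapes (sizes are made up):

```python
import torch

# Case 1: a batched 3D tensor, e.g. two requests with 5 patch rows of width 4.
batched = torch.randn(2, 5, 4)
flat_from_batched = torch.concat(list(batched))   # shape (10, 4)

# Case 2: a list of 2D tensors with differing first dimensions.
listed = [torch.randn(5, 4), torch.randn(7, 4)]
flat_from_list = torch.concat(listed)             # shape (12, 4)

print(flat_from_batched.shape, flat_from_list.shape)
```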


@@ -23,10 +23,11 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig,
MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-MultiModalDataItems, ProcessorInputs,
-PromptReplacement)
ProcessorInputs, PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
@@ -72,11 +73,19 @@ def get_ultravox_max_audio_tokens(ctx: InputContext):
class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
def _get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> ProcessorMixin:
return self.ctx.get_hf_processor()
def _get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self._get_hf_processor()
return hf_processor.audio_processor.feature_extractor # type: ignore
-def _get_processor_data(
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
@@ -84,33 +93,41 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate)
-return super()._get_processor_data(mm_items)
return super()._get_hf_mm_data(mm_items)
def _call_hf_processor(
self,
-hf_processor: ProcessorMixin,
prompt: str,
-processor_data: Mapping[str, object],
-mm_processor_kwargs: Mapping[str, object],
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
-processor_data = dict(processor_data)
-audios = processor_data.pop("audios", [])
# Text-only input not supported in composite processor
if not mm_data:
tokenizer = self._get_tokenizer()
prompt_ids = tokenizer.encode(
prompt,
add_special_tokens=False, # type: ignore
)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
if not audios:
return super()._call_hf_processor(
-hf_processor,
prompt=prompt,
-processor_data=processor_data,
-mm_processor_kwargs=mm_processor_kwargs,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
feature_extractor = self._get_feature_extractor()
-mm_processor_kwargs = dict(
-**mm_processor_kwargs,
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
-# Already resampled by _get_processor_data
# Already resampled by _get_hf_mm_data
assert is_list_of(audios, np.ndarray)
# Ultravox processor doesn't support multiple inputs,
@@ -119,13 +136,12 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
shared_outputs = {}
for audio in audios:
# NOTE: Ultravox processor accepts "audio" instead of "audios"
-item_processor_data = dict(**processor_data, audio=audio)
item_processor_data = dict(**mm_data, audio=audio)
item_outputs = super()._call_hf_processor(
-hf_processor,
prompt=prompt,
-processor_data=item_processor_data,
-mm_processor_kwargs=mm_processor_kwargs,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
)
audio_features.append(item_outputs.pop("audio_values")[0])
@@ -139,17 +155,28 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
)
return BatchFeature(combined_outputs)
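Because the Ultravox HF processor only accepts one audio clip per call, the merged processor loops over clips, calls the processor once per clip, and then merges the per-clip results into one BatchFeature. A simplified sketch of that pattern; `run_processor` and the output keys are stand-ins, not the real API:

```python
import numpy as np

def process_audios_one_by_one(audios: list[np.ndarray], run_processor) -> dict:
    # run_processor(audio) is a stand-in for one HF processor call on one clip.
    audio_features = []
    audio_token_len = []
    for audio in audios:
        item_outputs = run_processor(audio)
        audio_features.append(item_outputs["audio_values"][0])  # drop dummy batch dim
        audio_token_len.append(item_outputs["audio_token_len"])
    # The real code merges these per-clip lists into one BatchFeature.
    return {"audio_features": audio_features, "audio_token_len": audio_token_len}
```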
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
audio_features=MultiModalFieldConfig.batched("audio"),
audio_token_len=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
-hf_inputs: BatchFeature,
-mm_processor_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
placeholder = hf_processor.audio_token_replacement # type: ignore
def get_replacement_ultravox(item_idx: int):
-audio_token_len = hf_inputs["audio_token_len"][item_idx]
audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
return placeholder * audio_token_len
return [
@@ -168,14 +195,13 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
-audio_count = mm_counts["audio"]
audio_count = mm_counts.get("audio", 0)
audio = np.zeros(audio_len)
data = {"audio": [audio] * audio_count}
return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,
mm_data=data,
-mm_processor_kwargs={},
)
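The dummy inputs used for profiling are just silence at the feature extractor's maximum chunk length. With Whisper-style defaults (30 s chunks at 16 kHz, stated here as an assumption; the real values come from the feature extractor), that is 480,000 zero samples per audio item:

```python
import numpy as np

chunk_length = 30        # seconds (assumed Whisper-style default)
sampling_rate = 16_000   # Hz (assumed Whisper-style default)
audio_count = 2

audio_len = chunk_length * sampling_rate            # 480000 samples
data = {"audio": [np.zeros(audio_len)] * audio_count}
prompt_text = "<|audio|>" * audio_count
print(audio_len, prompt_text)  # 480000 <|audio|><|audio|>
```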