Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -31,11 +31,15 @@ import torch
import torch.nn as nn
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig)
Qwen2_5OmniConfig,
Qwen2_5OmniThinkerConfig,
)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
Qwen2_5OmniAudioEncoder)
Qwen2_5OmniAudioEncoder,
)
from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import (
Qwen2_5OmniProcessor)
Qwen2_5OmniProcessor,
)
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
@@ -44,33 +48,60 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
Qwen2_5_VisionTransformer,
Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs,
Qwen2_5_VLImagePixelInputs,
Qwen2_5_VLProcessingInfo,
Qwen2_5_VLVideoEmbeddingInputs,
Qwen2_5_VLVideoInputs,
Qwen2_5_VLVideoPixelInputs,
)
from vllm.model_executor.models.qwen2_audio import (
Qwen2AudioProcessingInfo, _get_feat_extract_output_lengths)
Qwen2AudioProcessingInfo,
_get_feat_extract_output_lengths,
)
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.inputs import (
ImageItem,
ModalityData,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
NestedTensors,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
DictEmbeddingItems,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
try:
import flash_attn
@@ -88,6 +119,7 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
- msl: Maximum sequence length
- tsl: Total sequence length
"""
type: Literal["audio_features"]
input_features: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
@@ -101,52 +133,55 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
def create_qwen2_5_omni_thinker_field_factory(
spatial_merge_size: int
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
MultiModalFieldConfig]]:
def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,
torch.Tensor]):
audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
torch.empty((0, )))
spatial_merge_size: int,
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str, MultiModalFieldConfig]]:
def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_feature_lengths = hf_inputs.get(
"audio_feature_lengths", torch.empty((0,))
)
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_pixel_grid_sizes = image_grid_thw.prod(-1)
image_embed_grid_sizes = (image_pixel_grid_sizes //
spatial_merge_size // spatial_merge_size)
image_embed_grid_sizes = (
image_pixel_grid_sizes // spatial_merge_size // spatial_merge_size
)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
spatial_merge_size)
video_embed_grid_sizes = (
video_grid_sizes // spatial_merge_size // spatial_merge_size
)
num_videos = len(video_grid_sizes)
return dict(
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_feature_lengths, dim=1),
"audio", audio_feature_lengths, dim=1
),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_pixel_grid_sizes),
"image", image_pixel_grid_sizes
),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_embed_grid_sizes),
"image", image_embed_grid_sizes
),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
"video", video_grid_sizes
),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_embed_grid_sizes),
"video", video_embed_grid_sizes
),
video_grid_thw=MultiModalFieldConfig.batched("video"),
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
use_audio_in_video=MultiModalFieldConfig.shared(
"video", num_videos),
use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
)
return _qwen2_5_omni_thinker_field_config
class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
def __init__(self, spatial_merge_size: int, *args, **kwargs):
self._spatial_merge_size = spatial_merge_size
super().__init__(self._spatial_merge_size, *args, **kwargs)
@@ -159,19 +194,18 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={
"input_audio_features", "audio_feature_lengths"
},
required_fields={"input_audio_features", "audio_feature_lengths"},
fields_factory=create_qwen2_5_omni_thinker_field_factory(
self._spatial_merge_size),
self._spatial_merge_size
),
)
return super()._parse_audio_data(data)
class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
Qwen2_5_VLProcessingInfo):
class Qwen2_5OmniThinkerProcessingInfo(
Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo
):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config
@@ -193,8 +227,8 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
class Qwen2_5OmniThinkerDummyInputsBuilder(
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]):
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]
):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0)
@@ -206,8 +240,11 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
image_token: str = hf_processor.image_token
video_token: str = hf_processor.video_token
return (audio_token * num_audios + image_token * num_images +
video_token * num_videos)
return (
audio_token * num_audios
+ image_token * num_images
+ video_token * num_videos
)
def get_dummy_mm_data(
self,
@@ -221,49 +258,55 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
feature_extractor = self.info.get_feature_extractor()
target_audio_length = min(
feature_extractor.chunk_length,
30,
) * feature_extractor.sampling_rate
target_width, target_height = \
self.info.get_image_size_with_most_features()
target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
target_audio_length = (
min(
feature_extractor.chunk_length,
30,
)
* feature_extractor.sampling_rate
)
target_width, target_height = self.info.get_image_size_with_most_features()
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts
)
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None
mm_data = {
"audio":
self._get_dummy_audios(length=target_audio_length,
num_audios=num_audios,
overrides=audio_overrides),
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides),
"video":
self._get_dummy_videos(width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides),
"audio": self._get_dummy_audios(
length=target_audio_length,
num_audios=num_audios,
overrides=audio_overrides,
),
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides,
),
"video": self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=target_num_frames,
num_videos=num_videos,
overrides=video_overrides,
),
}
return mm_data
class Qwen2_5OmniThinkerMultiModalProcessor(
BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]):
BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return Qwen2_5OmniThinkerMultiModalDataParser(
spatial_merge_size=self.info.get_hf_config(
).vision_config.spatial_merge_size,
target_sr=feature_extractor.sampling_rate)
spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size,
target_sr=feature_extractor.sampling_rate,
)
def _call_hf_processor(
self,
@@ -279,7 +322,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
if audios:
# NOTE: Qwen2.5-Omni processor accept "audio"
mm_data["audio"] = audios
mm_kwargs = dict(**mm_kwargs, )
mm_kwargs = dict(
**mm_kwargs,
)
hf_inputs = super()._call_hf_processor(
prompt=prompt,
@@ -288,17 +333,19 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
tok_kwargs=tok_kwargs,
)
input_features = hf_inputs.pop('input_features', None)
feature_attention_mask = hf_inputs.get('feature_attention_mask', None)
if ('input_audio_features' not in hf_inputs
and input_features is not None):
input_features = hf_inputs.pop("input_features", None)
feature_attention_mask = hf_inputs.get("feature_attention_mask", None)
if "input_audio_features" not in hf_inputs and input_features is not None:
if feature_attention_mask is not None:
input_features = input_features.permute(
0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
hf_inputs['input_audio_features'] = input_features
if ('audio_feature_lengths' not in hf_inputs
and feature_attention_mask is not None):
hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)
input_features = input_features.permute(0, 2, 1)[
feature_attention_mask.bool()
].permute(1, 0)
hf_inputs["input_audio_features"] = input_features
if (
"audio_feature_lengths" not in hf_inputs
and feature_attention_mask is not None
):
hf_inputs["audio_feature_lengths"] = feature_attention_mask.sum(-1)
video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
if video_second_per_grid is not None:
@@ -315,8 +362,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return create_qwen2_5_omni_thinker_field_factory(
self.info.get_hf_config().vision_config.spatial_merge_size)(
hf_inputs)
self.info.get_hf_config().vision_config.spatial_merge_size
)(hf_inputs)
def _maybe_apply_prompt_updates(
self,
@@ -335,13 +382,12 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
use_audio_in_video = False
if "video" in mm_kwargs:
video_items = [
item for item in mm_kwargs["video"] if item is not None
]
video_items = [item for item in mm_kwargs["video"] if item is not None]
# only check video items (if there are any)
if video_items:
use_audio_in_video = all(item["use_audio_in_video"].data
for item in video_items)
use_audio_in_video = all(
item["use_audio_in_video"].data for item in video_items
)
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
@@ -374,8 +420,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
vocab = tokenizer.get_vocab()
audio_token = processor.audio_token
@@ -392,12 +437,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
audio_output_lengths = []
elif audio_feature_lengths is not None:
_, audio_output_lens = _get_feat_extract_output_lengths(
audio_feature_lengths)
audio_feature_lengths
)
audio_output_lengths = audio_output_lens.tolist()
elif feature_attention_mask is not None:
assert isinstance(feature_attention_mask, torch.Tensor)
_, audio_output_lens = _get_feat_extract_output_lengths(
feature_attention_mask.sum(-1))
feature_attention_mask.sum(-1)
)
audio_output_lengths = audio_output_lens.tolist()
# number of audios read from video.
@@ -412,7 +459,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
audio = audios.get(item_idx)
raise ValueError(
f"The audio {audio} (len={len(audio)}) is too short "
"to be represented inside the model")
"to be represented inside the model"
)
return [audio_token_id] * num_features
@@ -424,21 +472,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
token_id = image_token_id if modality == "image" else video_token_id
return [token_id] * (int(grid_thw.prod()) // merge_length)
use_audio_in_video = hf_processor_mm_kwargs.get(
"use_audio_in_video", False)
use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
thinker_config = self.info.get_hf_config()
def get_replacement_qwen2_use_audio_in_video(item_idx: int):
nonlocal audio_in_video_item_idx
audio_num_features = audio_output_lengths[audio_in_video_item_idx +
item_idx]
audio_num_features = audio_output_lengths[
audio_in_video_item_idx + item_idx
]
video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
audio_in_video_item_idx += 1
second_per_grid_ts = hf_processor_mm_kwargs.get(
"second_per_grid_ts", None)
second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None)
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[item_idx]
else:
@@ -452,8 +499,10 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
)
video_replacement_fn = (
get_replacement_qwen2_use_audio_in_video if use_audio_in_video else
partial(get_replacement_qwen2_vision, modality="video"))
get_replacement_qwen2_use_audio_in_video
if use_audio_in_video
else partial(get_replacement_qwen2_vision, modality="video")
)
return [
PromptReplacement(
@@ -464,8 +513,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
PromptReplacement(
modality="image",
target=image_token,
replacement=partial(get_replacement_qwen2_vision,
modality="image"),
replacement=partial(get_replacement_qwen2_vision, modality="image"),
),
PromptReplacement(
modality="video",
@@ -518,8 +566,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
"""
mm_counts = mm_items.get_all_counts()
use_audio_in_video = hf_processor_mm_kwargs.get(
"use_audio_in_video", False)
use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
if use_audio_in_video and "video" in mm_counts:
assert "audio" in mm_counts
mm_counts["audio"] -= mm_counts["video"]
@@ -548,14 +595,11 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
class Qwen2_5OmniConditionalGenerationMixin:
def _validate_and_reshape_mm_tensor(self,
mm_input: object,
name: str,
dim: int = 0) -> torch.Tensor:
def _validate_and_reshape_mm_tensor(
self, mm_input: object, name: str, dim: int = 0
) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if dim == 0:
return mm_input.reshape(-1, *mm_input.shape[2:])
@@ -564,25 +608,31 @@ class Qwen2_5OmniConditionalGenerationMixin:
return torch.concat(mm_input, dim=dim)
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
input_audio_features = kwargs.pop('input_audio_features', None)
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
self, **kwargs: object
) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
input_audio_features = kwargs.pop("input_audio_features", None)
audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
feature_attention_mask = kwargs.pop("feature_attention_mask", None)
if input_audio_features is None:
return None
input_audio_features = self._validate_and_reshape_mm_tensor(
input_audio_features, 'input_audio_features', dim=1)
input_audio_features, "input_audio_features", dim=1
)
if feature_attention_mask is not None:
feature_attention_mask = self._validate_and_reshape_mm_tensor(
feature_attention_mask, 'feature_attention_mask')
feature_attention_mask, "feature_attention_mask"
)
if not isinstance(input_audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio input features. "
f"Got type: {type(input_audio_features)}")
raise ValueError(
"Incorrect type of audio input features. "
f"Got type: {type(input_audio_features)}"
)
return Qwen2_5OmniAudioFeatureInputs(
type="audio_features",
input_features=input_audio_features,
audio_feature_lengths=audio_feature_lengths,
feature_attention_mask=feature_attention_mask)
feature_attention_mask=feature_attention_mask,
)
def _parse_and_validate_image_input(
self,
@@ -597,31 +647,42 @@ class Qwen2_5OmniConditionalGenerationMixin:
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
pixel_values, "image pixel values"
)
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
image_grid_thw, "image grid_thw"
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
raise ValueError(
"Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}"
)
return Qwen2_5_VLImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
return Qwen2_5_VLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_embeds, "image embeds"
)
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
image_grid_thw, "image grid_thw"
)
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return Qwen2_5_VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw)
image_grid_thw=image_grid_thw,
)
def _parse_and_validate_video_input(
self,
@@ -636,9 +697,11 @@ class Qwen2_5OmniConditionalGenerationMixin:
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos, "video pixel values")
pixel_values_videos, "video pixel values"
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
video_grid_thw, "video grid_thw"
)
return Qwen2_5_VLVideoPixelInputs(
type="pixel_values_videos",
@@ -648,17 +711,22 @@ class Qwen2_5OmniConditionalGenerationMixin:
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds")
video_embeds, "video embeds"
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
video_grid_thw, "video grid_thw"
)
if not isinstance(video_embeds, torch.Tensor):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
raise ValueError(
"Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}"
)
return Qwen2_5_VLVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,
video_grid_thw=video_grid_thw)
video_grid_thw=video_grid_thw,
)
def _process_audio_input(
self,
@@ -666,35 +734,35 @@ class Qwen2_5OmniConditionalGenerationMixin:
audio_hashes: list[str] = None,
cached_audio_features: torch.Tensor = None,
) -> torch.Tensor:
input_features = audio_input["input_features"]
audio_feature_lengths = audio_input["audio_feature_lengths"]
if input_features.ndim == 3:
assert input_features.shape[0] == 1
input_features = input_features.squeeze(0)
if audio_feature_lengths.ndim == 2:
assert audio_feature_lengths.shape[
0] == 1 or audio_feature_lengths.shape[1] == 1
assert (
audio_feature_lengths.shape[0] == 1
or audio_feature_lengths.shape[1] == 1
)
if audio_feature_lengths.shape[0] == 1:
audio_feature_lengths = audio_feature_lengths.squeeze(0)
else:
audio_feature_lengths = audio_feature_lengths.squeeze(1)
audio_feat_lengths, audio_output_lengths = (
self.audio_tower._get_feat_extract_output_lengths(
audio_feature_lengths))
self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths)
)
audio_outputs = self.audio_tower(
input_features.to(self.audio_tower.dtype),
feature_lens=audio_feature_lengths,
aftercnn_lens=audio_feat_lengths,
)
return audio_outputs.last_hidden_state.split(
audio_output_lengths.tolist())
return audio_outputs.last_hidden_state.split(audio_output_lengths.tolist())
def _process_image_input(
self,
image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
self, image_input: Qwen2_5_VLImageInputs
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["image_embeds"].type(self.visual.dtype)
@@ -710,18 +778,18 @@ class Qwen2_5OmniConditionalGenerationMixin:
return image_embeds.split(sizes.tolist())
def _process_video_input(
self,
video_input: Qwen2_5_VLVideoInputs,
video_hashes: list[str] = None,
cached_video_embeds: torch.Tensor = None) -> torch.Tensor:
self,
video_input: Qwen2_5_VLVideoInputs,
video_hashes: list[str] = None,
cached_video_embeds: torch.Tensor = None,
) -> torch.Tensor:
if video_input["type"] == "video_embeds":
return video_input["video_embeds"].type(self.visual.dtype)
grid_thw = video_input["video_grid_thw"]
assert grid_thw.ndim == 2
pixel_values_videos = video_input["pixel_values_videos"].type(
self.visual.dtype)
pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
# Split concatenated embeddings for each video item.
merge_size = self.visual.spatial_merge_size
@@ -736,14 +804,19 @@ class Qwen2_5OmniConditionalGenerationMixin:
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
Qwen2_5OmniConditionalGenerationMixin):
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsLoRA,
Qwen2_5OmniConditionalGenerationMixin,
):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"thinker.lm_head.": "language_model.lm_head.",
"thinker.model.": "language_model.model.",
"thinker.": "",
})
}
)
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -775,7 +848,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
thinker_config: Qwen2_5OmniThinkerConfig = (
vllm_config.model_config.hf_config.thinker_config)
vllm_config.model_config.hf_config.thinker_config
)
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config
@@ -791,20 +865,20 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
logger.warning(
"flash_attn is not available, the model may not yield the "
"exactly same result as the transformers implementation "
"in the audio tower part.")
"in the audio tower part."
)
if multimodal_config.get_limit_per_prompt("audio"):
self.audio_tower = Qwen2_5OmniAudioEncoder(
thinker_config.audio_config)
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
else:
self.audio_tower = None
if multimodal_config.get_limit_per_prompt(
"image") or multimodal_config.get_limit_per_prompt("video"):
"image"
) or multimodal_config.get_limit_per_prompt("video"):
self.visual = Qwen2_5_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
1e-6),
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
@@ -820,7 +894,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
@@ -828,28 +903,34 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values", "image_embeds"
) and "image" not in mm_input_by_modality:
mm_input_by_modality[
"image"] = self._parse_and_validate_image_input(**kwargs)
if input_key in ("pixel_values_videos", "video_embeds"
) and "video" not in mm_input_by_modality:
mm_input_by_modality[
"video"] = self._parse_and_validate_video_input(**kwargs)
if input_key in ("input_audio_features"
) and "audio" not in mm_input_by_modality:
mm_input_by_modality[
"audio"] = self._parse_and_validate_audio_input(**kwargs)
if (
input_key in ("pixel_values", "image_embeds")
and "image" not in mm_input_by_modality
):
mm_input_by_modality["image"] = self._parse_and_validate_image_input(
**kwargs
)
if (
input_key in ("pixel_values_videos", "video_embeds")
and "video" not in mm_input_by_modality
):
mm_input_by_modality["video"] = self._parse_and_validate_video_input(
**kwargs
)
if (
input_key in ("input_audio_features")
and "audio" not in mm_input_by_modality
):
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
**kwargs
)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
**kwargs)
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return []
@@ -893,8 +974,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
handle_oov_mm_token=handle_oov_mm_token,
)
def get_multimodal_embeddings_v0(
self, **kwargs: object) -> Optional[NestedTensors]:
def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
@@ -926,10 +1006,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds)
hidden_states = self.language_model.model(
input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
)
return hidden_states
def compute_logits(
@@ -938,8 +1017,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = ["talker.", "token2wav."]
if self.audio_tower is None:
skip_prefixes.extend(["audio_tower."])
@@ -950,8 +1028,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
self,
skip_prefixes=skip_prefixes,
)
loaded_weights = loader.load_weights(weights,
mapper=self.hf_to_vllm_mapper)
loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded_weights
@@ -962,4 +1039,5 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="merger.",
tower_model=["visual.", "audio_tower."])
tower_model=["visual.", "audio_tower."],
)