Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-10-05 15:06:22 +01:00; committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
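
In short, the tree moves from yapf's aligned-continuation layout (with isort ordering imports) to ruff's black-compatible formatter, where a magic trailing comma forces one element per line inside brackets. A minimal before/after sketch using one import from the diff below (illustrative only, not part of the commit):

    # yapf + isort: continuation lines aligned under the opening parenthesis
    from transformers.models.qwen2_audio import (Qwen2AudioConfig,
                                                  Qwen2AudioEncoder,
                                                  Qwen2AudioProcessor)

    # ruff format: black-style hanging indent; the trailing comma keeps one name per line
    from transformers.models.qwen2_audio import (
        Qwen2AudioConfig,
        Qwen2AudioEncoder,
        Qwen2AudioProcessor,
    )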


@@ -22,29 +22,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
 from typing import Annotated, Any, Literal, Optional, Union

 import torch
 import torch.nn as nn
 from transformers import BatchFeature
-from transformers.models.qwen2_audio import (Qwen2AudioConfig,
-                                              Qwen2AudioEncoder,
-                                              Qwen2AudioProcessor)
+from transformers.models.qwen2_audio import (
+    Qwen2AudioConfig,
+    Qwen2AudioEncoder,
+    Qwen2AudioProcessor,
+)
 from transformers.models.whisper import WhisperFeatureExtractor

 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (AudioItem, ModalityData,
-                                    MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
-                                   ModalityDataItems, MultiModalDataItems,
-                                   MultiModalDataParser)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.inputs import (
+    AudioItem,
+    ModalityData,
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+    AudioProcessorItems,
+    DictEmbeddingItems,
+    ModalityDataItems,
+    MultiModalDataItems,
+    MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -60,6 +75,7 @@ class Qwen2AudioFeatureInputs(TensorSchema):
         - na: Number of audios
         - nmb: Number of mel bins
     """
+
     type: Literal["audio_features"]
     input_features: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
@@ -80,6 +96,7 @@ class Qwen2AudioEmbeddingInputs(TensorSchema):
         - hs: Hidden size (must match the hidden size of language model
             backbone)
     """
+
     type: Literal["audio_embeds"] = "audio_embeds"
     audio_embeds: Annotated[
@@ -94,7 +111,6 @@ Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]
 class Qwen2AudioMultiModalProjector(nn.Module):
-
     def __init__(self, audio_hidden_size: int, text_hidden_size: int):
         super().__init__()
         self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True)
@@ -112,15 +128,13 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
 class Qwen2AudioProcessingInfo(BaseProcessingInfo):
-
     def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2AudioConfig)

     def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
         return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

-    def get_feature_extractor(self,
-                              **kwargs: object) -> WhisperFeatureExtractor:
+    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
         hf_processor = self.get_hf_processor(**kwargs)
         feature_extractor = hf_processor.feature_extractor  # type: ignore
         assert isinstance(feature_extractor, WhisperFeatureExtractor)
@@ -130,9 +144,7 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
         return {"audio": None}


-class Qwen2AudioDummyInputsBuilder(
-        BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
-
+class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_audios = mm_counts.get("audio", 0)
@@ -156,10 +168,9 @@ class Qwen2AudioDummyInputsBuilder(
         audio_overrides = mm_options.get("audio") if mm_options else None

         return {
-            "audio":
-            self._get_dummy_audios(length=audio_len,
-                                   num_audios=num_audios,
-                                   overrides=audio_overrides)
+            "audio": self._get_dummy_audios(
+                length=audio_len, num_audios=num_audios, overrides=audio_overrides
+            )
         }
@@ -172,7 +183,6 @@ def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
 class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
-
     def _parse_audio_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
@@ -188,13 +198,10 @@ class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
         return super()._parse_audio_data(data)


-class Qwen2AudioMultiModalProcessor(
-        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
-
+class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
     def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
-        return Qwen2AudioMultiModalDataParser(
-            target_sr=feature_extractor.sampling_rate)
+        return Qwen2AudioMultiModalDataParser(target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
         self,
@@ -242,17 +249,14 @@ class Qwen2AudioMultiModalProcessor(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
-
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         tokenizer = self.info.get_tokenizer()
         vocab = tokenizer.get_vocab()

         # Use getattr with default to be compatible with transformers<4.48
         audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
-        audio_bos_token = getattr(processor, "audio_bos_token",
-                                  "<|audio_bos|>")
-        audio_eos_token = getattr(processor, "audio_eos_token",
-                                  "<|audio_eos|>")
+        audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
+        audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")

         audio_token_id = vocab[audio_token]
         audio_bos_id = vocab[audio_bos_token]
@@ -265,26 +269,27 @@ class Qwen2AudioMultiModalProcessor(
         else:
             assert isinstance(feature_attention_mask, torch.Tensor)
             _, audio_output_lens = _get_feat_extract_output_lengths(
-                feature_attention_mask.sum(-1))
+                feature_attention_mask.sum(-1)
+            )
             audio_output_lengths = audio_output_lens.tolist()

         def get_replacement_qwen2_audio(item_idx: int):
             if audio_output_lengths:
                 num_features = audio_output_lengths[item_idx]
             else:
                 audio_embeds = out_mm_data["audio_embeds"][item_idx]
-                assert len(audio_embeds.shape
-                           ) == 2, "audio_embeds must be a 2D tensor"
+                assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
                 num_features = audio_embeds.shape[0]

             if num_features == 0:
                 audios = mm_items.get_items("audio", AudioProcessorItems)
                 audio_len = audios.get_audio_length(item_idx)
-                raise ValueError(f"The audio (len={audio_len}) is too short "
-                                 "to be represented inside the model")
+                raise ValueError(
+                    f"The audio (len={audio_len}) is too short "
+                    "to be represented inside the model"
+                )

             audio_tokens = [audio_token_id] * num_features
@@ -305,10 +310,9 @@ class Qwen2AudioMultiModalProcessor(
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2AudioMultiModalProcessor,
     info=Qwen2AudioProcessingInfo,
-    dummy_inputs=Qwen2AudioDummyInputsBuilder)
-class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                         SupportsPP):
-
+    dummy_inputs=Qwen2AudioDummyInputsBuilder,
+)
+class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("audio"):
@@ -326,7 +330,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.audio_tower = Qwen2AudioEncoder(config.audio_config)
         self.multi_modal_projector = Qwen2AudioMultiModalProjector(
-            config.audio_config.d_model, config.text_config.hidden_size)
+            config.audio_config.d_model, config.text_config.hidden_size
+        )

         self.quant_config = quant_config
@@ -338,45 +343,53 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         )
         self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors)
+            self.language_model.make_empty_intermediate_tensors
+        )

-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
+    def _validate_and_reshape_mm_tensor(
+        self, mm_input: object, name: str
+    ) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
         else:
             return torch.concat(mm_input)

     def _parse_and_validate_audio_input(
-            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
-        input_features = kwargs.pop('input_features', None)
-        audio_embeds = kwargs.pop('audio_embeds', None)
-        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
+        self, **kwargs: object
+    ) -> Optional[Qwen2AudioInputs]:
+        input_features = kwargs.pop("input_features", None)
+        audio_embeds = kwargs.pop("audio_embeds", None)
+        feature_attention_mask = kwargs.pop("feature_attention_mask", None)
         if input_features is None and audio_embeds is None:
             return None

         if audio_embeds is not None:
             if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio embeds. "
-                                 f"Got type: {type(audio_embeds)}")
+                raise ValueError(
+                    f"Incorrect type of audio embeds. Got type: {type(audio_embeds)}"
+                )
             audio_embeds = self._validate_and_reshape_mm_tensor(
-                audio_embeds, "audio_embeds")
-            return Qwen2AudioEmbeddingInputs(type="audio_embeds",
-                                             audio_embeds=audio_embeds)
+                audio_embeds, "audio_embeds"
+            )
+            return Qwen2AudioEmbeddingInputs(
+                type="audio_embeds", audio_embeds=audio_embeds
+            )

         if input_features is not None:
             input_features = self._validate_and_reshape_mm_tensor(
-                input_features, 'input_features')
+                input_features, "input_features"
+            )
             feature_attention_mask = self._validate_and_reshape_mm_tensor(
-                feature_attention_mask, 'feature_attention_mask')
+                feature_attention_mask, "feature_attention_mask"
+            )
             return Qwen2AudioFeatureInputs(
                 type="audio_features",
                 input_features=input_features,
-                feature_attention_mask=feature_attention_mask)
+                feature_attention_mask=feature_attention_mask,
+            )

         raise AssertionError("This line should be unreachable.")
@@ -392,51 +405,62 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         audio_feat_lengths, audio_output_lengths = (
             self.audio_tower._get_feat_extract_output_lengths(
-                feature_attention_mask.sum(-1)))
+                feature_attention_mask.sum(-1)
+            )
+        )
         batch_size, _, max_mel_seq_len = input_features.shape
         max_seq_len = (max_mel_seq_len - 2) // 2 + 1
         # Create a sequence tensor of shape (batch_size, max_seq_len)
-        seq_range = (torch.arange(
-            0,
-            max_seq_len,
-            dtype=audio_feat_lengths.dtype,
-            device=audio_feat_lengths.device).unsqueeze(0).expand(
-                batch_size, max_seq_len))
+        seq_range = (
+            torch.arange(
+                0,
+                max_seq_len,
+                dtype=audio_feat_lengths.dtype,
+                device=audio_feat_lengths.device,
+            )
+            .unsqueeze(0)
+            .expand(batch_size, max_seq_len)
+        )
         lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
-            batch_size, max_seq_len)
+            batch_size, max_seq_len
+        )
         # Create mask
         padding_mask = seq_range >= lengths_expand
-        audio_attention_mask_ = padding_mask.view(
-            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
-                                                  max_seq_len)
+        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
+            batch_size, 1, max_seq_len, max_seq_len
+        )
         audio_attention_mask = audio_attention_mask_.to(
             dtype=self.audio_tower.conv1.weight.dtype,
-            device=self.audio_tower.conv1.weight.device)
+            device=self.audio_tower.conv1.weight.device,
+        )
         audio_attention_mask[audio_attention_mask_] = float("-inf")
-        audio_outputs = self.audio_tower(input_features,
-                                         attention_mask=audio_attention_mask)
+        audio_outputs = self.audio_tower(
+            input_features, attention_mask=audio_attention_mask
+        )
         selected_audio_feature = audio_outputs.last_hidden_state
         audio_features = self.multi_modal_projector(selected_audio_feature)
         num_audios, max_audio_tokens, embed_dim = audio_features.shape
         audio_output_lengths = audio_output_lengths.unsqueeze(1)
-        audio_features_mask = torch.arange(max_audio_tokens).expand(
-            num_audios, max_audio_tokens).to(
-                audio_output_lengths.device) < audio_output_lengths
-        masked_audio_features = audio_features[audio_features_mask].view(
-            -1, embed_dim)
+        audio_features_mask = (
+            torch.arange(max_audio_tokens)
+            .expand(num_audios, max_audio_tokens)
+            .to(audio_output_lengths.device)
+            < audio_output_lengths
+        )
+        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)
         # Split to tuple of embeddings for individual audio input.
-        return torch.split(masked_audio_features,
-                           audio_output_lengths.flatten().tolist())
+        return torch.split(
+            masked_audio_features, audio_output_lengths.flatten().tolist()
+        )

     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
             return []
@@ -451,14 +475,12 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-
         if intermediate_tensors is not None:
             inputs_embeds = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
         return hidden_states

     def compute_logits(
@@ -467,7 +489,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states)

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)