Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
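Every hunk below is a pure formatting change to the Qwen2-Audio model file: yapf's bracket-aligned continuation style becomes ruff's Black-style layout (4-space hanging indents, magic trailing commas, double quotes), isort's packed multi-line imports become one name per line, and the blank lines yapf tolerated directly after class and def headers are dropped. As a minimal sketch of the two styles (illustrative names only, not taken from this commit):

    # yapf aligned continuation arguments with the opening bracket:
    features_old = dict(type="audio_features",
                        input_features=[0.1, 0.2],
                        feature_attention_mask=None)

    # ruff format uses a 4-space hanging indent; the magic trailing comma
    # keeps the call expanded with one argument per line:
    features_new = dict(
        type="audio_features",
        input_features=[0.1, 0.2],
        feature_attention_mask=None,
    )

    assert features_old == features_new  # only the layout differs

The usual replacement commands are "ruff format" (for yapf) and "ruff check --select I --fix" (for isort's import sorting); the exact ruff configuration this PR adds is not shown in the hunks below.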
@@ -22,29 +22,44 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
+
 from collections.abc import Iterable, Mapping, Sequence
 from typing import Annotated, Any, Literal, Optional, Union

 import torch
 import torch.nn as nn
 from transformers import BatchFeature
-from transformers.models.qwen2_audio import (Qwen2AudioConfig,
-                                             Qwen2AudioEncoder,
-                                             Qwen2AudioProcessor)
+from transformers.models.qwen2_audio import (
+    Qwen2AudioConfig,
+    Qwen2AudioEncoder,
+    Qwen2AudioProcessor,
+)
 from transformers.models.whisper import WhisperFeatureExtractor

 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (AudioItem, ModalityData,
-                                    MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
-from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
-                                   ModalityDataItems, MultiModalDataItems,
-                                   MultiModalDataParser)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.inputs import (
+    AudioItem,
+    ModalityData,
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import (
+    AudioProcessorItems,
+    DictEmbeddingItems,
+    ModalityDataItems,
+    MultiModalDataItems,
+    MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -60,6 +75,7 @@ class Qwen2AudioFeatureInputs(TensorSchema):
         - na: Number of audios
         - nmb: Number of mel bins
     """
+
     type: Literal["audio_features"]
     input_features: Annotated[
         Union[torch.Tensor, list[torch.Tensor]],
@@ -80,6 +96,7 @@ class Qwen2AudioEmbeddingInputs(TensorSchema):
         - hs: Hidden size (must match the hidden size of language model
             backbone)
     """
+
     type: Literal["audio_embeds"] = "audio_embeds"

     audio_embeds: Annotated[
@@ -94,7 +111,6 @@ Qwen2AudioInputs = Union[Qwen2AudioFeatureInputs, Qwen2AudioEmbeddingInputs]


 class Qwen2AudioMultiModalProjector(nn.Module):
-
     def __init__(self, audio_hidden_size: int, text_hidden_size: int):
         super().__init__()
         self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True)
@@ -112,15 +128,13 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):


 class Qwen2AudioProcessingInfo(BaseProcessingInfo):
-
     def get_hf_config(self):
         return self.ctx.get_hf_config(Qwen2AudioConfig)

     def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
         return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)

-    def get_feature_extractor(self,
-                              **kwargs: object) -> WhisperFeatureExtractor:
+    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
         hf_processor = self.get_hf_processor(**kwargs)
         feature_extractor = hf_processor.feature_extractor  # type: ignore
         assert isinstance(feature_extractor, WhisperFeatureExtractor)
@@ -130,9 +144,7 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
         return {"audio": None}


-class Qwen2AudioDummyInputsBuilder(
-        BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
-
+class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingInfo]):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_audios = mm_counts.get("audio", 0)

@@ -156,10 +168,9 @@ class Qwen2AudioDummyInputsBuilder(
         audio_overrides = mm_options.get("audio") if mm_options else None

         return {
-            "audio":
-            self._get_dummy_audios(length=audio_len,
-                                   num_audios=num_audios,
-                                   overrides=audio_overrides)
+            "audio": self._get_dummy_audios(
+                length=audio_len, num_audios=num_audios, overrides=audio_overrides
+            )
         }


@@ -172,7 +183,6 @@ def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):


 class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
-
     def _parse_audio_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
@@ -188,13 +198,10 @@ class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
         return super()._parse_audio_data(data)


-class Qwen2AudioMultiModalProcessor(
-        BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
-
+class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
     def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
-        return Qwen2AudioMultiModalDataParser(
-            target_sr=feature_extractor.sampling_rate)
+        return Qwen2AudioMultiModalDataParser(target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
         self,
@@ -242,17 +249,14 @@ class Qwen2AudioMultiModalProcessor(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
-
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
         tokenizer = self.info.get_tokenizer()
         vocab = tokenizer.get_vocab()

         # Use getattr with default to be compatible with transformers<4.48
         audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
-        audio_bos_token = getattr(processor, "audio_bos_token",
-                                  "<|audio_bos|>")
-        audio_eos_token = getattr(processor, "audio_eos_token",
-                                  "<|audio_eos|>")
+        audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
+        audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")

         audio_token_id = vocab[audio_token]
         audio_bos_id = vocab[audio_bos_token]
@@ -265,26 +269,27 @@ class Qwen2AudioMultiModalProcessor(
         else:
             assert isinstance(feature_attention_mask, torch.Tensor)
             _, audio_output_lens = _get_feat_extract_output_lengths(
-                feature_attention_mask.sum(-1))
+                feature_attention_mask.sum(-1)
+            )

             audio_output_lengths = audio_output_lens.tolist()

         def get_replacement_qwen2_audio(item_idx: int):
-
             if audio_output_lengths:
                 num_features = audio_output_lengths[item_idx]
             else:
                 audio_embeds = out_mm_data["audio_embeds"][item_idx]
-                assert len(audio_embeds.shape
-                           ) == 2, "audio_embeds must be a 2D tensor"
+                assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
                 num_features = audio_embeds.shape[0]

             if num_features == 0:
                 audios = mm_items.get_items("audio", AudioProcessorItems)
                 audio_len = audios.get_audio_length(item_idx)

-                raise ValueError(f"The audio (len={audio_len}) is too short "
-                                 "to be represented inside the model")
+                raise ValueError(
+                    f"The audio (len={audio_len}) is too short "
+                    "to be represented inside the model"
+                )

             audio_tokens = [audio_token_id] * num_features

@@ -305,10 +310,9 @@ class Qwen2AudioMultiModalProcessor(
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2AudioMultiModalProcessor,
     info=Qwen2AudioProcessingInfo,
-    dummy_inputs=Qwen2AudioDummyInputsBuilder)
-class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                         SupportsPP):
-
+    dummy_inputs=Qwen2AudioDummyInputsBuilder,
+)
+class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("audio"):
@@ -326,7 +330,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,

         self.audio_tower = Qwen2AudioEncoder(config.audio_config)
         self.multi_modal_projector = Qwen2AudioMultiModalProjector(
-            config.audio_config.d_model, config.text_config.hidden_size)
+            config.audio_config.d_model, config.text_config.hidden_size
+        )

         self.quant_config = quant_config

@@ -338,45 +343,53 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         )

         self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors)
+            self.language_model.make_empty_intermediate_tensors
+        )

-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
+    def _validate_and_reshape_mm_tensor(
+        self, mm_input: object, name: str
+    ) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
         else:
             return torch.concat(mm_input)

     def _parse_and_validate_audio_input(
-            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
-        input_features = kwargs.pop('input_features', None)
-        audio_embeds = kwargs.pop('audio_embeds', None)
-        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
+        self, **kwargs: object
+    ) -> Optional[Qwen2AudioInputs]:
+        input_features = kwargs.pop("input_features", None)
+        audio_embeds = kwargs.pop("audio_embeds", None)
+        feature_attention_mask = kwargs.pop("feature_attention_mask", None)

         if input_features is None and audio_embeds is None:
             return None

         if audio_embeds is not None:
             if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio embeds. "
-                                 f"Got type: {type(audio_embeds)}")
+                raise ValueError(
+                    f"Incorrect type of audio embeds. Got type: {type(audio_embeds)}"
+                )
             audio_embeds = self._validate_and_reshape_mm_tensor(
-                audio_embeds, "audio_embeds")
-            return Qwen2AudioEmbeddingInputs(type="audio_embeds",
-                                             audio_embeds=audio_embeds)
+                audio_embeds, "audio_embeds"
+            )
+            return Qwen2AudioEmbeddingInputs(
+                type="audio_embeds", audio_embeds=audio_embeds
+            )

         if input_features is not None:
             input_features = self._validate_and_reshape_mm_tensor(
-                input_features, 'input_features')
+                input_features, "input_features"
+            )
             feature_attention_mask = self._validate_and_reshape_mm_tensor(
-                feature_attention_mask, 'feature_attention_mask')
+                feature_attention_mask, "feature_attention_mask"
+            )
             return Qwen2AudioFeatureInputs(
                 type="audio_features",
                 input_features=input_features,
-                feature_attention_mask=feature_attention_mask)
+                feature_attention_mask=feature_attention_mask,
+            )

         raise AssertionError("This line should be unreachable.")

@@ -392,51 +405,62 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,

         audio_feat_lengths, audio_output_lengths = (
             self.audio_tower._get_feat_extract_output_lengths(
-                feature_attention_mask.sum(-1)))
+                feature_attention_mask.sum(-1)
+            )
+        )

         batch_size, _, max_mel_seq_len = input_features.shape
         max_seq_len = (max_mel_seq_len - 2) // 2 + 1
         # Create a sequence tensor of shape (batch_size, max_seq_len)
-        seq_range = (torch.arange(
-            0,
-            max_seq_len,
-            dtype=audio_feat_lengths.dtype,
-            device=audio_feat_lengths.device).unsqueeze(0).expand(
-                batch_size, max_seq_len))
+        seq_range = (
+            torch.arange(
+                0,
+                max_seq_len,
+                dtype=audio_feat_lengths.dtype,
+                device=audio_feat_lengths.device,
+            )
+            .unsqueeze(0)
+            .expand(batch_size, max_seq_len)
+        )
         lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
-            batch_size, max_seq_len)
+            batch_size, max_seq_len
+        )
         # Create mask
         padding_mask = seq_range >= lengths_expand

-        audio_attention_mask_ = padding_mask.view(
-            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
-                                                  max_seq_len)
+        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
+            batch_size, 1, max_seq_len, max_seq_len
+        )
         audio_attention_mask = audio_attention_mask_.to(
             dtype=self.audio_tower.conv1.weight.dtype,
-            device=self.audio_tower.conv1.weight.device)
+            device=self.audio_tower.conv1.weight.device,
+        )
         audio_attention_mask[audio_attention_mask_] = float("-inf")

-        audio_outputs = self.audio_tower(input_features,
-                                         attention_mask=audio_attention_mask)
+        audio_outputs = self.audio_tower(
+            input_features, attention_mask=audio_attention_mask
+        )
         selected_audio_feature = audio_outputs.last_hidden_state
         audio_features = self.multi_modal_projector(selected_audio_feature)
         num_audios, max_audio_tokens, embed_dim = audio_features.shape
         audio_output_lengths = audio_output_lengths.unsqueeze(1)
-        audio_features_mask = torch.arange(max_audio_tokens).expand(
-            num_audios, max_audio_tokens).to(
-                audio_output_lengths.device) < audio_output_lengths
-        masked_audio_features = audio_features[audio_features_mask].view(
-            -1, embed_dim)
+        audio_features_mask = (
+            torch.arange(max_audio_tokens)
+            .expand(num_audios, max_audio_tokens)
+            .to(audio_output_lengths.device)
+            < audio_output_lengths
+        )
+        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

         # Split to tuple of embeddings for individual audio input.
-        return torch.split(masked_audio_features,
-                           audio_output_lengths.flatten().tolist())
+        return torch.split(
+            masked_audio_features, audio_output_lengths.flatten().tolist()
+        )

     def get_language_model(self) -> torch.nn.Module:
         return self.language_model

-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
             return []
@@ -451,14 +475,12 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-
         if intermediate_tensors is not None:
             inputs_embeds = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
         return hidden_states

     def compute_logits(
@@ -467,7 +489,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states)

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
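Because every hunk above is formatting-only, the change can be sanity-checked by confirming the file parses to the same syntax tree before and after the rewrite (black's --safe mode enforces this same invariant). A small sketch of that check, assuming the two versions of the file are available as strings:

    import ast

    def same_program(before: str, after: str) -> bool:
        # A formatting-only change leaves the parsed AST unchanged.
        return ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))

    old = "x = dict(a=1,\n         b=2)"
    new = "x = dict(\n    a=1,\n    b=2,\n)"
    assert same_program(old, new)  # the layout differs, the program does not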