Files
vllm/vllm/model_executor/models/glmasr.py
baonudesifeizhai d722e9e614 Add GLM-ASR multimodal support (#31436)
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-31 23:12:24 +08:00

546 lines
19 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, cast
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature
from transformers.models.glmasr import GlmAsrConfig, GlmAsrEncoder, GlmAsrProcessor
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityData,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .audioflamingo3 import (
AudioFlamingo3MultiModalDataParser,
AudioFlamingo3MultiModalProcessor,
AudioFlamingo3ProcessingInfo,
)
from .audioflamingo3 import (
_audioflamingo3_field_config as _glmasr_field_config,
)
from .glmasr_utils import (
DEFAULT_CONV_PARAMS,
DEFAULT_MAX_AUDIO_LEN_S,
DEFAULT_MERGE_FACTOR,
_flatten_audio_features_by_length,
_get_audio_output_lengths_for_tower,
_get_num_features_for_item,
_group_audio_embeddings,
_normalize_chunk_counts,
)
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsTranscription,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .whisper import ISO639_1_SUPPORTED_LANGS
class GlmAsrFeatureInputs(TensorSchema):
    """
    Chunked audio features to be encoded by the audio tower.

    Dimensions:
        - num_chunks: Total number of audio chunks over all audios (flattened)
        - nmb: Number of mel bins
        - chunk_length: Size of the last axis of each chunk (dynamic)
        - num_audios: Number of original audio files
    """

    type: Literal["audio_features"]

    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape(
            "num_chunks",
            "nmb",
            "chunk_length",
            dynamic_dims={"chunk_length"},
        ),
    ]

    feature_attention_mask: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape(
            "num_chunks",
            "chunk_length",
            dynamic_dims={"chunk_length"},
        ),
    ]

    # Number of chunks each original audio was split into.
    chunk_counts: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_audios"),
    ]
class GlmAsrEmbeddingInputs(TensorSchema):
    """
    Precomputed audio embeddings that bypass the audio tower.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features (dynamic, may differ per item)
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    type: Literal["audio_embeds"] = "audio_embeds"

    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape(
            "bn",
            "naf",
            "hs",
            dynamic_dims={"naf"},
        ),
    ]


# Validated audio input: either raw chunked features or ready-made embeddings.
GlmAsrInputs: TypeAlias = GlmAsrFeatureInputs | GlmAsrEmbeddingInputs
class GlmAsrMultiModalProjector(nn.Module):
    """Two-layer MLP from audio-encoder width to the text model's hidden size.

    Shapes (last axis): ``audio_config.intermediate_size`` ->
    ``2 * text_config.hidden_size`` -> ``text_config.hidden_size``, with
    ``config.projector_hidden_act`` applied between the two linear layers.
    """

    def __init__(
        self,
        config: GlmAsrConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        text_hidden = config.text_config.hidden_size
        expanded_hidden = text_hidden * 2

        self.linear_1 = ColumnParallelLinear(
            input_size=config.audio_config.intermediate_size,
            output_size=expanded_hidden,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            input_size=expanded_hidden,
            output_size=text_hidden,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project ``audio_features`` into the text embedding space."""
        # The parallel linear layers return (output, bias); only output is used.
        expanded, _bias = self.linear_1(audio_features)
        activated = self.act(expanded)
        projected, _bias = self.linear_2(activated)
        return projected
class GlmAsrProcessingInfo(AudioFlamingo3ProcessingInfo):
    """AudioFlamingo3 processing info, narrowed to GLM-ASR's HF types."""

    def get_hf_config(self) -> GlmAsrConfig:
        """Return the model's HF config as a ``GlmAsrConfig``."""
        return self.ctx.get_hf_config(GlmAsrConfig)

    def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor:
        """Return the HF ``GlmAsrProcessor``, forwarding ``kwargs``."""
        return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the parent's feature extractor, checked to be Whisper-style."""
        extractor = super().get_feature_extractor(**kwargs)
        assert isinstance(extractor, WhisperFeatureExtractor)
        return extractor
class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]):
    """Builds dummy GLM-ASR text and audio inputs for multimodal profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the HF audio token repeated once per requested audio."""
        audio_count = mm_counts.get("audio", 0)
        placeholder = self.info.get_hf_processor().audio_token
        return placeholder * audio_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy audios whose length is the processor's maximum.

        The length in samples is ``max_audio_len`` (from the HF processor,
        falling back to ``DEFAULT_MAX_AUDIO_LEN_S``) times the feature
        extractor's sampling rate. ``seq_len`` is not used.
        """
        sampling_rate = self.info.get_feature_extractor().sampling_rate
        audio_count = mm_counts.get("audio", 0)
        overrides = None if not mm_options else mm_options.get("audio")

        hf_processor = self.info.get_hf_processor()
        max_seconds = getattr(hf_processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
        num_samples = int(max_seconds * sampling_rate)

        dummy_audios = self._get_dummy_audios(
            length=num_samples,
            num_audios=audio_count,
            overrides=overrides,
        )
        return {"audio": dummy_audios}
class GlmAsrMultiModalDataParser(AudioFlamingo3MultiModalDataParser):
    """Audio data parser that also accepts dicts of precomputed embeddings."""

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio data.

        Non-dict inputs go through the parent parser. A dict is treated as
        precomputed embeddings and must contain an ``audio_embeds`` entry.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_glmasr_field_config,
        )
class GlmAsrMultiModalProcessor(AudioFlamingo3MultiModalProcessor):
    """Multimodal processor for GLM-ASR built on the AudioFlamingo3 pipeline.

    After the parent HF processing it renames ``input_features_mask`` to
    ``feature_attention_mask`` and recomputes per-audio ``chunk_counts`` with
    GLM-ASR's windowing rule. Each audio placeholder token is then expanded to
    the number of audio embeddings that audio produces.
    """

    def _get_data_parser(self) -> MultiModalDataParser:
        """Return the GLM-ASR parser, using the feature extractor's sampling
        rate as ``target_sr``."""
        feature_extractor = self.info.get_feature_extractor()
        return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate)

    def _calculate_chunk_counts(
        self,
        audio_list: list[Any],
        feature_extractor: WhisperFeatureExtractor,
        processor: GlmAsrProcessor,
    ) -> list[int]:
        """Calculate chunk counts for each audio.

        Each audio is split into windows of ``chunk_length`` seconds: the
        sample count is divided by the window size, rounded up, with a minimum
        of one. The result is capped at ``max_audio_len // chunk_length``
        windows.
        """
        sampling_rate = feature_extractor.sampling_rate
        chunk_length = feature_extractor.chunk_length
        max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
        window_size = int(sampling_rate * chunk_length)
        max_windows = int(max_audio_len // chunk_length)

        chunk_counts = []
        for audio in audio_list:
            n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
            n_chunks = max(1, (n_samples + window_size - 1) // window_size)
            chunk_counts.append(min(n_chunks, max_windows))
        return chunk_counts

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and adapt its outputs for GLM-ASR.

        Prompts without audio are only tokenized. Otherwise the parent
        implementation is called, its mask is renamed to
        ``feature_attention_mask``, and ``chunk_counts`` is added as a
        ``torch.long`` tensor.
        """
        # Normalize input: accept the deprecated "audios" key.
        if "audios" in mm_data:
            mm_data["audio"] = mm_data.pop("audios")

        # Build the list of audios without truth-testing ``audio`` itself: a
        # single numpy array / torch tensor raises on ``bool()`` when it has
        # more than one element.
        audio = mm_data.get("audio")
        if audio is None:
            audio_list: list[Any] = []
        elif isinstance(audio, list):
            audio_list = audio
        else:
            audio_list = [audio]

        # Early return for text-only.
        if not audio_list:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Get processor for chunk counts calculation
        processor = self.info.get_hf_processor(**mm_kwargs)

        # Call parent method (it will handle sampling_rate)
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # Postprocess: rename mask and add chunk counts.
        if "input_features_mask" in outputs:
            outputs["feature_attention_mask"] = outputs.pop("input_features_mask")

        # Override chunk counts calculation with GLM-ASR specific logic
        chunk_counts = self._calculate_chunk_counts(
            audio_list, processor.feature_extractor, processor
        )
        outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the field config shared with AudioFlamingo3."""
        return _glmasr_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each audio token with one token per audio embedding.

        Raises:
            ValueError: (from the replacement callback) if an audio item would
                produce zero embeddings.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
        config = self.info.get_hf_config()

        audio_token = getattr(processor, "audio_token", "<|pad|>")
        audio_token_id = vocab.get(audio_token)
        if audio_token_id is None:
            # Token string not in the vocab; use the processor's own id.
            audio_token_id = processor.audio_token_id

        merge_factor = getattr(config, "merge_factor", DEFAULT_MERGE_FACTOR)
        conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS)

        # These lookups do not depend on the item, so do them once rather than
        # inside the per-item callback.
        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        chunk_counts = out_mm_data.get("chunk_counts")
        audio_embeds = out_mm_data.get("audio_embeds")

        def get_replacement_glmasr(item_idx: int):
            num_features = _get_num_features_for_item(
                feature_attention_mask,
                chunk_counts,
                item_idx,
                audio_embeds,
                merge_factor,
                conv_params,
            )
            if num_features == 0:
                raise ValueError("Audio is too short")

            audio_tokens = [audio_token_id] * int(num_features)
            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_glmasr,
            )
        ]
@MULTIMODAL_REGISTRY.register_processor(
    GlmAsrMultiModalProcessor,
    info=GlmAsrProcessingInfo,
    dummy_inputs=GlmAsrDummyInputsBuilder,
)
class GlmAsrForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
    """GLM-ASR speech-to-text model.

    Audio is encoded by ``GlmAsrEncoder``. ``GlmAsrMultiModalProjector`` maps
    the result into the text embedding space, and a language model built with
    the ``LlamaForCausalLM`` architecture consumes it.
    """

    supported_languages = ISO639_1_SUPPORTED_LANGS

    # Packed (fused) module name -> names of the modules it combines.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.audio_tower = GlmAsrEncoder(config.audio_config)
        self.multi_modal_projector = GlmAsrMultiModalProjector(
            config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.quant_config = quant_config

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
            architectures=["LlamaForCausalLM"],
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an audio item (``i`` is unused).

        Raises:
            ValueError: If ``modality`` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return "<|begin_of_audio|><|pad|><|end_of_audio|>"
        raise ValueError("Only audio modality is supported")

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the module prefixes of the LM, connector and audio tower."""
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _parse_and_validate_audio_input(self, **kwargs: object) -> GlmAsrInputs | None:
        """Build a validated audio input from ``kwargs``, consuming its keys.

        Precomputed ``audio_embeds`` take precedence over ``input_features``.
        Returns None when neither is present.
        """
        audio_embeds = kwargs.pop("audio_embeds", None)
        if audio_embeds is not None:
            return GlmAsrEmbeddingInputs(type="audio_embeds", audio_embeds=audio_embeds)

        input_features = kwargs.pop("input_features", None)
        if input_features is None:
            return None

        return GlmAsrFeatureInputs(
            type="audio_features",
            input_features=input_features,
            feature_attention_mask=kwargs.pop("feature_attention_mask", None),
            chunk_counts=kwargs.pop("chunk_counts", None),
        )

    def _process_audio_input(
        self, audio_input: GlmAsrInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Turn validated audio input into per-audio embeddings.

        Precomputed embeddings are returned as a tuple. Chunked features are
        run through the audio tower and projector. Each chunk is then trimmed
        to its valid output length, and the chunks are regrouped per original
        audio using ``chunk_counts``.
        """
        if audio_input["type"] == "audio_embeds":
            return tuple(audio_input["audio_embeds"])

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]
        # Lists of per-item tensors are flattened along the chunk dimension.
        # Features and mask are handled independently: a list in one does not
        # imply a list in the other (torch.cat on a single tensor would fail,
        # and ``.sum`` below would fail on a list).
        if isinstance(input_features, list):
            input_features = torch.cat(input_features, dim=0)
        if isinstance(feature_attention_mask, list):
            feature_attention_mask = torch.cat(feature_attention_mask, dim=0)

        num_chunks = input_features.shape[0]
        chunk_counts = _normalize_chunk_counts(
            audio_input.get("chunk_counts"), num_chunks=num_chunks
        )

        audio_hidden_states = self.audio_tower(input_features).last_hidden_state
        # Regroup encoder outputs so the last axis has the projector's input
        # width (``audio_config.intermediate_size``).
        audio_hidden_states = audio_hidden_states.reshape(
            num_chunks,
            -1,
            self.config.audio_config.intermediate_size,
        )
        audio_features = self.multi_modal_projector(audio_hidden_states)

        merge_factor = getattr(self.config, "merge_factor", DEFAULT_MERGE_FACTOR)
        conv_params = getattr(self.config, "conv_params", DEFAULT_CONV_PARAMS)
        # Valid output length per chunk, derived from the per-chunk sum of the
        # attention mask.
        audio_output_lengths = _get_audio_output_lengths_for_tower(
            self.audio_tower,
            feature_attention_mask.sum(-1),
            merge_factor,
            conv_params,
        )
        masked_audio_features = _flatten_audio_features_by_length(
            audio_features, audio_output_lengths
        )
        chunk_embeddings = torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )
        return _group_audio_embeddings(chunk_embeddings, chunk_counts)

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute audio embeddings; returns ``[]`` when no audio is given."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        masked_audio_features = self._process_audio_input(audio_input)
        return masked_audio_features

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner language model; extra ``kwargs`` are ignored."""
        # When intermediate tensors are supplied (presumably a non-first
        # pipeline stage), input embeddings are not used.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``audio_tower.embed_positions``."""
        # NOTE(review): presumably the encoder provides its own positional
        # embeddings, so checkpoint copies are ignored — confirm.
        skip_prefixes = ["audio_tower.embed_positions"]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)

    @classmethod
    def _get_audio_token(cls, model_config: ModelConfig) -> str:
        """Get the audio token from processor.

        Similar to get_placeholder_str but returns single token; falls back to
        ``"<|pad|>"`` if the processor has no ``audio_token``.
        """
        processor = cached_processor_from_config(model_config)
        return getattr(processor, "audio_token", "<|pad|>")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Return clip-length and sample-rate limits (``task_type`` unused)."""
        processor = cached_processor_from_config(model_config)
        feature_extractor = processor.feature_extractor
        max_audio_clip_s = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=feature_extractor.sampling_rate,
        )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests.

        Builds a single-turn chat prompt (audio token + instruction), applies
        the tokenizer's chat template and returns the token ids together with
        the audio as multimodal data.

        Raises:
            ValueError: If ``task_type`` is neither "transcribe" nor "translate".
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_token = cls._get_audio_token(model_config)

        if task_type == "translate":
            # ``to_language`` may be None; without a fallback the prompt would
            # read "translate the speech to None". Default to English.
            target_language = to_language or "en"
            full_lang_name_to = cls.supported_languages.get(
                target_language, target_language
            )
            user_content = f"{audio_token}translate the speech to {full_lang_name_to}"
        elif task_type == "transcribe":
            user_content = (
                f"{audio_token}can you transcribe the speech into a written format?"
            )
        else:
            raise ValueError(f"Unsupported task type {task_type}")

        messages = [{"role": "user", "content": user_content}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        prompt_token_ids = tokenizer.encode(prompt)

        prompt_dict = {
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": {"audio": audio},
        }
        return cast(PromptType, prompt_dict)