Add GLM-ASR multimodal support (#31436)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
baonudesifeizhai
2025-12-31 10:12:24 -05:00
committed by GitHub
parent cf16342d43
commit d722e9e614
8 changed files with 764 additions and 2 deletions


@@ -774,10 +774,11 @@ Speech2Text models trained specifically for Automatic Speech Recognition.
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|-------------------|----------------------|---------------------------|
| `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ |
| `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ |
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | ✅︎ | ✅︎ |
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | |
!!! note
`VoxtralForConditionalGeneration` requires `mistral-common[audio]` to be installed.
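A minimal offline-inference sketch for the new GLM-ASR entry (illustrative only; the prompt format mirrors the example script added in this commit, and `AudioAsset` is just a convenient test clip):

```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "zai-org/GLM-ASR-Nano-2512"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# One <|pad|> placeholder per audio clip, wrapped in the chat template.
messages = [{"role": "user", "content": "<|pad|>Please transcribe the audio."}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

audio, sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate

llm = LLM(
    model=model_name,
    trust_remote_code=True,
    max_model_len=4096,
    limit_mm_per_prompt={"audio": 1},
)
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": (audio, sample_rate)}},
    SamplingParams(temperature=0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```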


@@ -358,6 +358,34 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
)
# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
model_name = "zai-org/GLM-ASR-Nano-2512"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# GLM-ASR uses the <|pad|> token as the audio placeholder
audio_placeholder = "<|pad|>" * audio_count
messages = [{"role": "user", "content": f"{audio_placeholder}{question}"}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count},
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "Whisper only support single audio input per prompt"
@@ -381,6 +409,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"audioflamingo3": run_audioflamingo3,
"gemma3n": run_gemma3n,
"glmasr": run_glmasr,
"granite_speech": run_granite_speech,
"midashenglm": run_midashenglm,
"minicpmo": run_minicpmo,


@@ -84,6 +84,19 @@ def qwen3_vl_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
return mm_data
def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
"""
Patch the multimodal data for the GLM-ASR model.
GLM-ASR's processor expects a 1:1 pairing of text prompt and audio, so we keep at most one audio item per prompt.
"""
if "audio" in mm_data:
audio = mm_data["audio"]
if isinstance(audio, list) and len(audio) > 1:
# Limit to single audio to match text requirement
mm_data["audio"] = [audio[0]]
return mm_data
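# Illustration with hypothetical clip_a / clip_b waveforms:
#   glmasr_patch_mm_data({"audio": [clip_a, clip_b]}) -> {"audio": [clip_a]}
# Only the first clip is kept, enforcing the 1:1 text/audio pairing described above.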
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
@@ -108,6 +121,7 @@ MM_DATA_PATCHES = {
"ernie4_5_moe_vl": qwen3_vl_patch_mm_data,
"glm4v": glm4_1v_patch_mm_data,
"glm4v_moe": glm4_1v_patch_mm_data,
"glmasr": glmasr_patch_mm_data,
"qwen3_vl": qwen3_vl_patch_mm_data,
"qwen3_vl_moe": qwen3_vl_patch_mm_data,
}


@@ -655,6 +655,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
"zai-org/GLM-ASR-Nano-2512",
trust_remote_code=True,
min_transformers_version="5.0",
),
"GraniteSpeechForConditionalGeneration": _HfExamplesInfo(
"ibm-granite/granite-speech-3.3-2b"
),


@@ -0,0 +1,545 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, cast
import numpy as np
import torch
import torch.nn as nn
from transformers import BatchFeature
from transformers.models.glmasr import GlmAsrConfig, GlmAsrEncoder, GlmAsrProcessor
from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityData,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .audioflamingo3 import (
AudioFlamingo3MultiModalDataParser,
AudioFlamingo3MultiModalProcessor,
AudioFlamingo3ProcessingInfo,
)
from .audioflamingo3 import (
_audioflamingo3_field_config as _glmasr_field_config,
)
from .glmasr_utils import (
DEFAULT_CONV_PARAMS,
DEFAULT_MAX_AUDIO_LEN_S,
DEFAULT_MERGE_FACTOR,
_flatten_audio_features_by_length,
_get_audio_output_lengths_for_tower,
_get_num_features_for_item,
_group_audio_embeddings,
_normalize_chunk_counts,
)
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
SupportsTranscription,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .whisper import ISO639_1_SUPPORTED_LANGS
class GlmAsrFeatureInputs(TensorSchema):
"""
Dimensions:
- num_chunks: Number of audio chunks (flattened)
- nmb: Number of mel bins
- chunk_length: Number of feature frames per chunk (dynamic)
- num_audios: Number of original audio files
"""
type: Literal["audio_features"]
input_features: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("num_chunks", "nmb", "chunk_length", dynamic_dims={"chunk_length"}),
]
feature_attention_mask: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("num_chunks", "chunk_length", dynamic_dims={"chunk_length"}),
]
chunk_counts: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("num_audios"),
]
class GlmAsrEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size
- naf: Number of audio features
- hs: Hidden size (must match the hidden size of language model
backbone)
"""
type: Literal["audio_embeds"] = "audio_embeds"
audio_embeds: Annotated[
list[torch.Tensor],
TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
]
GlmAsrInputs: TypeAlias = GlmAsrFeatureInputs | GlmAsrEmbeddingInputs
class GlmAsrMultiModalProjector(nn.Module):
def __init__(
self,
config: GlmAsrConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.linear_1 = ColumnParallelLinear(
input_size=config.audio_config.intermediate_size,
output_size=config.text_config.hidden_size * 2,
quant_config=quant_config,
prefix=f"{prefix}.linear_1",
)
self.act = get_act_fn(config.projector_hidden_act)
self.linear_2 = RowParallelLinear(
input_size=config.text_config.hidden_size * 2,
output_size=config.text_config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.linear_2",
)
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.linear_1(audio_features)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.linear_2(hidden_states)
return hidden_states
class GlmAsrProcessingInfo(AudioFlamingo3ProcessingInfo):
def get_hf_config(self) -> GlmAsrConfig:
return self.ctx.get_hf_config(GlmAsrConfig)
def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor:
return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs)
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
# Reuse parent implementation, but add type annotation and assertion
feature_extractor = super().get_feature_extractor(**kwargs)
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
return hf_processor.audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
max_audio_len = getattr(
self.info.get_hf_processor(), "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S
)
audio_len = int(max_audio_len * sampling_rate)
return {
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
)
}
class GlmAsrMultiModalDataParser(AudioFlamingo3MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
fields_factory=_glmasr_field_config,
)
return super()._parse_audio_data(data)
class GlmAsrMultiModalProcessor(AudioFlamingo3MultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _calculate_chunk_counts(
self,
audio_list: list[Any],
feature_extractor: WhisperFeatureExtractor,
processor: GlmAsrProcessor,
) -> list[int]:
"""Calculate chunk counts for each audio."""
sampling_rate = feature_extractor.sampling_rate
chunk_length = feature_extractor.chunk_length
max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
window_size = int(sampling_rate * chunk_length)
max_windows = int(max_audio_len // chunk_length)
chunk_counts = []
for audio in audio_list:
n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
n_chunks = max(1, (n_samples + window_size - 1) // window_size)
chunk_counts.append(min(n_chunks, max_windows))
return chunk_counts
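# Worked example, assuming the Whisper feature-extractor defaults of
# sampling_rate=16000 and chunk_length=30 with max_audio_len=655:
#   window_size = 480_000 samples, max_windows = 655 // 30 = 21
#   a 75 s clip (1_200_000 samples) -> ceil(1_200_000 / 480_000) = 3 chunks
#   anything longer than 21 * 30 s = 630 s is capped at 21 chunks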
def _call_hf_processor(
self,
prompt: str,
mm_data: dict[str, object],
mm_kwargs: Mapping[str, Any],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Normalize input: handle deprecated key and list conversion.
if "audios" in mm_data:
mm_data["audio"] = mm_data.pop("audios")
audio = mm_data.get("audio", [])
audio_list = [audio] if audio and not isinstance(audio, list) else audio
# Early return for text-only.
if not audio_list:
prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
# Get processor for chunk counts calculation
processor = self.info.get_hf_processor(**mm_kwargs)
# Call parent method (it will handle sampling_rate)
outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
# Postprocess: rename mask and add chunk counts.
if "input_features_mask" in outputs:
outputs["feature_attention_mask"] = outputs.pop("input_features_mask")
# Override chunk counts calculation with GLM-ASR specific logic
chunk_counts = self._calculate_chunk_counts(
audio_list, processor.feature_extractor, processor
)
outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)
return outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _glmasr_field_config(hf_inputs)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
config = self.info.get_hf_config()
audio_token = getattr(processor, "audio_token", "<|pad|>")
audio_token_id = vocab.get(audio_token)
if audio_token_id is None:
audio_token_id = processor.audio_token_id
merge_factor = getattr(config, "merge_factor", DEFAULT_MERGE_FACTOR)
out_mm_data = out_mm_kwargs.get_data()
feature_attention_mask = out_mm_data.get("feature_attention_mask")
chunk_counts = out_mm_data.get("chunk_counts")
def get_replacement_glmasr(item_idx: int):
conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS)
audio_embeds = out_mm_data.get("audio_embeds")
num_features = _get_num_features_for_item(
feature_attention_mask,
chunk_counts,
item_idx,
audio_embeds,
merge_factor,
conv_params,
)
if num_features == 0:
raise ValueError("Audio is too short")
audio_tokens = [audio_token_id] * int(num_features)
return PromptUpdateDetails.select_token_id(
audio_tokens,
embed_token_id=audio_token_id,
)
return [
PromptReplacement(
modality="audio",
target=audio_token,
replacement=get_replacement_glmasr,
)
]
@MULTIMODAL_REGISTRY.register_processor(
GlmAsrMultiModalProcessor,
info=GlmAsrProcessingInfo,
dummy_inputs=GlmAsrDummyInputsBuilder,
)
class GlmAsrForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
supported_languages = ISO639_1_SUPPORTED_LANGS
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.audio_tower = GlmAsrEncoder(config.audio_config)
self.multi_modal_projector = GlmAsrMultiModalProjector(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.quant_config = quant_config
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["LlamaForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("audio"):
return "<|begin_of_audio|><|pad|><|end_of_audio|>"
raise ValueError("Only audio modality is supported")
def get_mm_mapping(self) -> MultiModelKeys:
return MultiModelKeys.from_string_field(
language_model="language_model.",
connector="multi_modal_projector.",
tower_model="audio_tower.",
)
def _parse_and_validate_audio_input(self, **kwargs: object) -> GlmAsrInputs | None:
audio_embeds = kwargs.pop("audio_embeds", None)
if audio_embeds is not None:
return GlmAsrEmbeddingInputs(type="audio_embeds", audio_embeds=audio_embeds)
input_features = kwargs.pop("input_features", None)
if input_features is None:
return None
return GlmAsrFeatureInputs(
type="audio_features",
input_features=input_features,
feature_attention_mask=kwargs.pop("feature_attention_mask", None),
chunk_counts=kwargs.pop("chunk_counts", None),
)
def _process_audio_input(
self, audio_input: GlmAsrInputs
) -> torch.Tensor | tuple[torch.Tensor, ...]:
if audio_input["type"] == "audio_embeds":
return tuple(audio_input["audio_embeds"])
input_features = audio_input["input_features"]
feature_attention_mask = audio_input["feature_attention_mask"]
if isinstance(input_features, list):
input_features = torch.cat(input_features, dim=0)
feature_attention_mask = torch.cat(feature_attention_mask, dim=0)
num_chunks = input_features.shape[0]
chunk_counts = _normalize_chunk_counts(
audio_input.get("chunk_counts"), num_chunks=num_chunks
)
audio_hidden_states = self.audio_tower(input_features).last_hidden_state
audio_hidden_states = audio_hidden_states.reshape(
num_chunks,
-1,
self.config.audio_config.intermediate_size,
)
audio_features = self.multi_modal_projector(audio_hidden_states)
merge_factor = getattr(self.config, "merge_factor", DEFAULT_MERGE_FACTOR)
conv_params = getattr(self.config, "conv_params", DEFAULT_CONV_PARAMS)
audio_output_lengths = _get_audio_output_lengths_for_tower(
self.audio_tower,
feature_attention_mask.sum(-1),
merge_factor,
conv_params,
)
masked_audio_features = _flatten_audio_features_by_length(
audio_features, audio_output_lengths
)
chunk_embeddings = torch.split(
masked_audio_features, audio_output_lengths.flatten().tolist()
)
return _group_audio_embeddings(chunk_embeddings, chunk_counts)
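# Shape walkthrough with illustrative numbers: for 3 chunks the tower output is
# reshaped to (3, T, intermediate_size) and projected to (3, T, hidden_size);
# _flatten_audio_features_by_length keeps only the valid positions per chunk,
# torch.split restores per-chunk tensors, and _group_audio_embeddings
# concatenates them per original audio (e.g. chunk_counts [2, 1] -> 2 tensors).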
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return []
masked_audio_features = self._process_audio_input(audio_input)
return masked_audio_features
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids,
positions,
intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
skip_prefixes = ["audio_tower.embed_positions"]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
@classmethod
def _get_audio_token(cls, model_config: ModelConfig) -> str:
"""Get the audio token from processor.
Similar to get_placeholder_str but returns single token.
"""
processor = cached_processor_from_config(model_config)
return getattr(processor, "audio_token", "<|pad|>")
@classmethod
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
processor = cached_processor_from_config(model_config)
feature_extractor = processor.feature_extractor
max_audio_clip_s = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
return SpeechToTextConfig(
max_audio_clip_s=max_audio_clip_s,
sample_rate=feature_extractor.sampling_rate,
)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig,
language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: str | None,
) -> PromptType:
"""Get the generation prompt to be used for transcription requests."""
tokenizer = cached_tokenizer_from_config(model_config)
audio_token = cls._get_audio_token(model_config)
if task_type == "translate":
full_lang_name_to = cls.supported_languages.get(to_language, to_language)
user_content = f"{audio_token}translate the speech to {full_lang_name_to}"
elif task_type == "transcribe":
user_content = (
f"{audio_token}can you transcribe the speech into a written format?"
)
else:
raise ValueError(f"Unsupported task type {task_type}")
messages = [{"role": "user", "content": user_content}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
prompt_token_ids = tokenizer.encode(prompt)
prompt_dict = {
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": {"audio": audio},
}
return cast(PromptType, prompt_dict)
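# Usage sketch (illustrative): when the model is served with
#   vllm serve zai-org/GLM-ASR-Nano-2512 --trust-remote-code
# a transcription request is turned into the chat-template prompt built above,
# with the raw waveform attached as multi_modal_data["audio"].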


@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import cast
import torch
import torch.nn as nn
DEFAULT_MAX_AUDIO_LEN_S = 655
DEFAULT_MERGE_FACTOR = 4
# Default convolution parameters: (padding, kernel_size, stride)
# These correspond to the two conv layers in GlmAsrEncoder
DEFAULT_CONV_PARAMS = [(1, 3, 1), (1, 3, 2)]
def _calculate_conv_output_length(
input_length: torch.Tensor, padding: int, kernel_size: int, stride: int
) -> torch.Tensor:
"""Calculate Conv1d output length using standard formula."""
# Standard formula: floor((input + 2*padding - kernel_size) / stride) + 1
return (input_length + 2 * padding - kernel_size) // stride + 1
def _as_list_chunk_counts(
chunk_counts: torch.Tensor | list[int] | list[torch.Tensor],
) -> list[int]:
if isinstance(chunk_counts, torch.Tensor):
return chunk_counts.tolist()
if chunk_counts and isinstance(chunk_counts[0], torch.Tensor):
tensor_counts = cast(list[torch.Tensor], chunk_counts)
return [int(c.item()) for c in tensor_counts]
return [int(c) for c in chunk_counts]
def _normalize_chunk_counts(
chunk_counts: torch.Tensor | list[int] | list[torch.Tensor] | None,
num_chunks: int,
) -> list[int]:
if chunk_counts is None:
return [1] * num_chunks
return _as_list_chunk_counts(chunk_counts)
def _get_audio_output_lengths_from_lengths(
audio_lengths: torch.Tensor,
merge_factor: int,
conv_params: list[tuple[int, int, int]],
) -> torch.Tensor:
for padding, kernel_size, stride in conv_params:
audio_lengths = _calculate_conv_output_length(
audio_lengths, padding, kernel_size, stride
)
return (audio_lengths - merge_factor) // merge_factor + 1
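# Worked example with DEFAULT_CONV_PARAMS [(1, 3, 1), (1, 3, 2)] and merge_factor 4:
#   3000 mel frames (one 30 s Whisper chunk)
#   -> (3000 + 2 - 3) // 1 + 1 = 3000 after the first conv
#   -> (3000 + 2 - 3) // 2 + 1 = 1500 after the second conv
#   -> (1500 - 4) // 4 + 1 = 375 audio embeddings per chunk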
def _get_audio_output_lengths_from_mask(
mask: torch.Tensor,
merge_factor: int,
conv_params: list[tuple[int, int, int]],
) -> torch.Tensor:
audio_lengths = mask.sum(-1)
return _get_audio_output_lengths_from_lengths(
audio_lengths, merge_factor, conv_params
)
def _get_audio_output_lengths_for_tower(
audio_tower: nn.Module,
audio_lengths: torch.Tensor,
merge_factor: int,
conv_params: list[tuple[int, int, int]],
) -> torch.Tensor:
if hasattr(audio_tower, "_get_feat_extract_output_lengths"):
_, audio_output_lengths = audio_tower._get_feat_extract_output_lengths(
audio_lengths
)
return audio_output_lengths
return _get_audio_output_lengths_from_lengths(
audio_lengths, merge_factor, conv_params
)
def _flatten_audio_features_by_length(
audio_features: torch.Tensor,
audio_output_lengths: torch.Tensor,
) -> torch.Tensor:
num_chunks, max_audio_tokens, embed_dim = audio_features.shape
audio_output_lengths = audio_output_lengths.unsqueeze(1)
audio_features_mask = (
torch.arange(max_audio_tokens)
.expand(num_chunks, max_audio_tokens)
.to(audio_output_lengths.device)
< audio_output_lengths
)
return audio_features[audio_features_mask].view(-1, embed_dim)
def _group_audio_embeddings(
chunk_embeddings: Sequence[torch.Tensor],
chunk_counts: Sequence[int],
) -> tuple[torch.Tensor, ...]:
grouped_embeddings = []
current_idx = 0
for count in chunk_counts:
audio_chunks = chunk_embeddings[current_idx : current_idx + count]
grouped_embeddings.append(torch.cat(audio_chunks, dim=0))
current_idx += count
return tuple(grouped_embeddings)
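# Example: chunk_embeddings = (e0, e1, e2) with chunk_counts = [2, 1]
# returns (cat([e0, e1]), e2) -- one embedding tensor per original audio clip.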
def _normalize_to_tensor(mask: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
"""Convert mask to tensor, handling both list and tensor formats."""
if isinstance(mask, list):
return (
torch.stack(mask)
if mask and isinstance(mask[0], torch.Tensor)
else torch.tensor(mask)
)
return mask
def _extract_mask_for_item(
feature_attention_mask: torch.Tensor | list[torch.Tensor],
chunk_counts: torch.Tensor | list[int] | None,
item_idx: int,
) -> torch.Tensor:
"""Extract attention mask for a specific audio item."""
if chunk_counts is None:
# Single item per audio
mask = feature_attention_mask[item_idx]
if isinstance(feature_attention_mask, torch.Tensor):
return mask.unsqueeze(0)
return _normalize_to_tensor(mask)
# Multiple chunks per audio: calculate slice indices
counts = _as_list_chunk_counts(chunk_counts)
start_idx = sum(counts[:item_idx])
end_idx = start_idx + counts[item_idx]
# Extract slice
if isinstance(feature_attention_mask, torch.Tensor):
return feature_attention_mask[start_idx:end_idx]
mask_slice = feature_attention_mask[start_idx:end_idx]
return _normalize_to_tensor(mask_slice)
def _get_num_features_for_item(
feature_attention_mask: torch.Tensor | None,
chunk_counts: torch.Tensor | list[int] | None,
item_idx: int,
audio_embeds: list[torch.Tensor] | None,
merge_factor: int,
conv_params: list[tuple[int, int, int]],
) -> int:
"""Get number of features for a specific audio item."""
if feature_attention_mask is not None:
mask = _extract_mask_for_item(feature_attention_mask, chunk_counts, item_idx)
audio_output_lengths = _get_audio_output_lengths_from_mask(
mask, merge_factor, conv_params
)
return audio_output_lengths.sum().item()
if audio_embeds is not None:
return audio_embeds[item_idx].shape[0]
raise ValueError("Either feature_attention_mask or audio_embeds must be provided")


@@ -304,6 +304,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm",
"Gemma3nForConditionalGeneration",
),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), # noqa: E501
"Glm4vMoeForConditionalGeneration": ("glm4_1v", "Glm4vMoeForConditionalGeneration"), # noqa: E501


@@ -136,6 +136,7 @@ class HFConfigParser(ConfigParserBase):
model,
revision=revision,
code_revision=code_revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs,
)
@@ -157,6 +158,7 @@ class HFConfigParser(ConfigParserBase):
model,
revision=revision,
code_revision=code_revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
**kwargs,
)