1167 lines
42 KiB
Python
1167 lines
42 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
from typing import Annotated, Any, Literal, TypeAlias
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import BatchFeature
|
|
from transformers.models.glmasr import GlmAsrConfig, GlmAsrProcessor
|
|
from transformers.models.whisper import WhisperFeatureExtractor
|
|
|
|
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
|
from vllm.config.multimodal import BaseDummyOptions
|
|
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
|
from vllm.inputs import ModalityData, MultiModalDataDict, PromptType, TokensPrompt
|
|
from vllm.model_executor.layers.activation import get_act_fn
|
|
from vllm.model_executor.layers.attention import MMEncoderAttention
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding.common import ApplyRotaryEmb
|
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
|
from vllm.multimodal.inputs import (
|
|
MultiModalFieldConfig,
|
|
MultiModalKwargsItems,
|
|
)
|
|
from vllm.multimodal.parse import (
|
|
DictEmbeddingItems,
|
|
ModalityDataItems,
|
|
MultiModalDataItems,
|
|
MultiModalDataParser,
|
|
)
|
|
from vllm.multimodal.processing import (
|
|
BaseDummyInputsBuilder,
|
|
BaseMultiModalProcessor,
|
|
BaseProcessingInfo,
|
|
PromptReplacement,
|
|
PromptUpdate,
|
|
PromptUpdateDetails,
|
|
)
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.tokenizers import cached_tokenizer_from_config
|
|
from vllm.transformers_utils.processor import cached_processor_from_config
|
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
|
|
|
from .glmasr_utils import (
|
|
DEFAULT_CONV_PARAMS,
|
|
DEFAULT_MAX_AUDIO_LEN_S,
|
|
DEFAULT_MERGE_FACTOR,
|
|
_flatten_audio_features_by_length,
|
|
_get_audio_output_lengths_for_tower,
|
|
_group_audio_embeddings,
|
|
_normalize_chunk_counts,
|
|
)
|
|
from .interfaces import (
|
|
MultiModalEmbeddings,
|
|
SupportsLoRA,
|
|
SupportsMultiModal,
|
|
SupportsPP,
|
|
SupportsTranscription,
|
|
)
|
|
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
|
from .whisper import ISO639_1_SUPPORTED_LANGS
|
|
|
|
|
|
class GlmAsrEncoderRotaryEmbedding(nn.Module):
    """
    Rotary position frequency generator for the GLM-ASR audio encoder.

    Only the inverse-frequency vector is kept (as a non-persistent buffer).
    The per-position angles are rebuilt on every ``forward`` call, on whatever
    device that buffer currently lives on; callers apply ``.cos()`` and
    ``.sin()`` to the returned tensor themselves.
    """

    def __init__(self, config) -> None:
        """
        Args:
            config: Encoder config. Reads ``hidden_size`` and
                ``num_attention_heads`` (and ``head_dim`` if present), plus
                either a ``rope_parameters`` mapping or a ``rope_theta``
                attribute.
        """
        super().__init__()

        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )

        # A missing or empty ``rope_parameters`` falls back to plain attributes.
        rope_cfg = getattr(config, "rope_parameters", None)
        if rope_cfg:
            base = rope_cfg.get("rope_theta", 10000.0)
            rotary_fraction = rope_cfg.get("partial_rotary_factor", 1.0)
            dim = int(head_dim * rotary_fraction)
            self.attention_scaling = rope_cfg.get("attention_scaling", 1.0)
        else:
            base = getattr(config, "rope_theta", 10000.0)
            dim = head_dim
            self.attention_scaling = 1.0

        self.dim = dim
        self.head_dim = head_dim

        # One inverse frequency per rotated channel pair: base ** (-2i / dim).
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents), persistent=False)

    def forward(self, seq_len: int) -> torch.Tensor:
        """
        Build the rotation angles for positions ``0 .. seq_len - 1``.

        Args:
            seq_len: Number of positions to generate angles for.

        Returns:
            Tensor of shape [seq_len, dim/2] holding position * inv_freq,
            multiplied by ``attention_scaling``.
        """
        buf = self.inv_freq
        positions = torch.arange(seq_len, device=buf.device, dtype=buf.dtype)
        angles = torch.outer(positions, buf)
        # NOTE(review): the scaling is applied to the angles, i.e. before the
        # caller takes cos/sin. It is a no-op at the default of 1.0 -- confirm
        # this is intended for other values.
        return angles * self.attention_scaling
|
|
|
|
|
|
class GlmAsrEncoderAttention(nn.Module):
    """
    Multi-head (optionally grouped-query) self-attention for the GLM-ASR encoder.

    Q/K/V come from one fused ``QKVParallelLinear``. The leading ``rotary_dim``
    channels of every Q and K head are rotated with ``ApplyRotaryEmb``, the
    attention itself is delegated to ``MMEncoderAttention``, and the result is
    projected back with a ``RowParallelLinear``.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Encoder config; reads ``hidden_size``,
                ``num_attention_heads``, optional ``num_key_value_heads`` and
                the rotary settings (``rope_parameters`` or
                ``partial_rotary_factor``).
            quant_config: Optional quantization config forwarded to the
                parallel linear layers.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = getattr(
            config, "num_key_value_heads", config.num_attention_heads
        )
        self.head_dim = self.hidden_size // self.num_heads

        # Head counts owned by this tensor-parallel rank; KV heads are clamped
        # to at least one per rank.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_rank = self.num_heads // self.tp_size
        self.num_kv_heads_per_rank = max(1, self.num_kv_heads // self.tp_size)

        # Use QKVParallelLinear for fused QKV projection
        # Note: GLM-ASR uses bias on Q and V, but not K
        # For simplicity with QKVParallelLinear, we use bias=True for all
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        # Because the checkpoint carries no K bias (see note above), the K
        # slice of the fused bias is never written while loading weights.
        # Zero the whole bias up front so that slice cannot hold leftover
        # allocator contents; the Q and V slices are overwritten on load.
        fused_bias = getattr(self.qkv_proj, "bias", None)
        if fused_bias is not None:
            with torch.no_grad():
                fused_bias.zero_()

        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Use vLLM's ApplyRotaryEmb CustomOp
        # enforce_enable=True ensures the op is always enabled (important for ViT)
        # NOTE(review): the default partial_rotary_factor here is 0.5, while
        # GlmAsrEncoderRotaryEmbedding defaults to 1.0 when the key is absent.
        # The two only agree when the config sets the factor explicitly --
        # confirm which default is intended.
        rope_params = getattr(config, "rope_parameters", None)
        if rope_params:
            partial_rotary_factor = rope_params.get("partial_rotary_factor", 0.5)
        else:
            partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
        self.rotary_dim = int(self.head_dim * partial_rotary_factor)
        self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

        # Attention backend selection is left to MMEncoderAttention.
        self.attn = MMEncoderAttention(
            num_heads=self.num_heads_per_rank,
            head_size=self.head_dim,
            scale=self.head_dim**-0.5,
            num_kv_heads=self.num_kv_heads_per_rank,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
            rotary_pos_emb_cos: [seq_len, rotary_dim/2] - cosine of rotary embeddings
            rotary_pos_emb_sin: [seq_len, rotary_dim/2] - sine of rotary embeddings

        Returns:
            [batch_size, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Fused projection; the second return value is discarded.
        qkv, _ = self.qkv_proj(hidden_states)

        # Split the fused output into this rank's q, k and v slices.
        q_size = self.num_heads_per_rank * self.head_dim
        kv_size = self.num_kv_heads_per_rank * self.head_dim
        q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

        # Reshape to [batch, seq, num_heads, head_dim] for ApplyRotaryEmb
        q = q.view(batch_size, seq_len, self.num_heads_per_rank, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_kv_heads_per_rank, self.head_dim)

        # Rotate only the first ``rotary_dim`` channels of each head, writing
        # the result back in place; the remaining channels pass through.
        q[..., : self.rotary_dim] = self.apply_rotary_emb(
            q[..., : self.rotary_dim], rotary_pos_emb_cos, rotary_pos_emb_sin
        )
        k[..., : self.rotary_dim] = self.apply_rotary_emb(
            k[..., : self.rotary_dim], rotary_pos_emb_cos, rotary_pos_emb_sin
        )

        # MMEncoderAttention expects [batch, seq, num_heads, head_dim]
        # It handles GQA internally via repeat_interleave
        attn_output = self.attn(q, k, v)

        # Merge the heads back into a single trailing dimension.
        attn_output = attn_output.view(batch_size, seq_len, -1)

        output, _ = self.o_proj(attn_output)
        return output
|
|
|
|
|
|
class GlmAsrEncoderMLP(nn.Module):
    """
    Feed-forward block of a GLM-ASR encoder layer.

    ``fc1`` (column-parallel) expands to ``intermediate_size``, the configured
    activation is applied, and ``fc2`` (row-parallel) projects back to
    ``hidden_size``. Both linear layers carry a bias.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Encoder config; reads ``hidden_size``,
                ``intermediate_size`` and ``hidden_act``.
            quant_config: Optional quantization config for both linear layers.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.fc1 = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc1",
        )
        self.act_fn = get_act_fn(config.hidden_act)
        self.fc2 = RowParallelLinear(
            input_size=self.intermediate_size,
            output_size=self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.fc2",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> fc2 and return the projected tensor."""
        # The parallel linear layers return a pair; only the first element
        # (the output tensor) is used.
        expanded, _ = self.fc1(hidden_states)
        activated = self.act_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
|
|
|
|
|
|
class GlmAsrEncoderLayer(nn.Module):
    """
    One pre-norm transformer layer of the GLM-ASR encoder.

    Each of the two sub-blocks (self-attention, MLP) is preceded by its own
    LayerNorm and wrapped in a residual connection.
    """

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Encoder config; reads ``hidden_size`` and the optional
                ``layer_norm_eps`` (1e-5 when absent).
            quant_config: Optional quantization config passed to sub-modules.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = GlmAsrEncoderAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = GlmAsrEncoderMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        eps = getattr(config, "layer_norm_eps", 1e-5)
        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        rotary_pos_emb_cos: torch.Tensor,
        rotary_pos_emb_sin: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
            rotary_pos_emb_cos: [seq_len, rotary_dim/2] - cosine of rotary embeddings
            rotary_pos_emb_sin: [seq_len, rotary_dim/2] - sine of rotary embeddings

        Returns:
            [batch_size, seq_len, hidden_size]
        """
        # Attention sub-block: norm -> attention -> add the skip path.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: norm -> MLP -> add the skip path.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
|
|
|
|
|
|
class _GlmAsrEncoderOutput:
|
|
"""
|
|
Simple output container compatible with transformers' BaseModelOutput.
|
|
|
|
This lightweight container holds the encoder output and is compatible
|
|
with the transformers library's output format while being more efficient
|
|
than a full dataclass.
|
|
|
|
Attributes:
|
|
last_hidden_state: Final layer hidden states from the encoder.
|
|
Shape: [batch_size, seq_len, hidden_size]
|
|
"""
|
|
|
|
__slots__ = ("last_hidden_state",)
|
|
|
|
def __init__(self, last_hidden_state: torch.Tensor):
|
|
self.last_hidden_state = last_hidden_state
|
|
|
|
|
|
class GlmAsrEncoder(nn.Module):
    """
    vLLM-native GLM-ASR audio encoder.

    Pipeline: two 1-D convolutions with GELU (the second has stride 2), a
    stack of ``GlmAsrEncoderLayer`` transformer layers driven by rotary
    position embeddings, and a final LayerNorm. The transformer layers use
    vLLM's parallel linear layers, so tensor parallelism and quantization are
    inherited from them.
    """

    # Checkpoints store separate q/k/v projections; this module fuses them.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Audio encoder config; reads ``num_mel_bins``,
                ``hidden_size``, ``num_hidden_layers`` and the optional
                ``layer_norm_eps`` (1e-5 when absent).
            quant_config: Optional quantization config for the layers.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.config = config

        width = config.hidden_size

        # Convolutional front end: mel bins -> hidden, then a stride-2 conv.
        self.conv1 = nn.Conv1d(config.num_mel_bins, width, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(width, width, kernel_size=3, stride=2, padding=1)

        # Transformer stack.
        stack = []
        for layer_idx in range(config.num_hidden_layers):
            stack.append(
                GlmAsrEncoderLayer(
                    config,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layers.{layer_idx}",
                )
            )
        self.layers = nn.ModuleList(stack)

        # Output normalisation and rotary frequency generator.
        self.norm = nn.LayerNorm(width, eps=getattr(config, "layer_norm_eps", 1e-5))
        self.rotary_emb = GlmAsrEncoderRotaryEmbedding(config)

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the sequence lengths produced by the two convolutions.

        Args:
            input_lengths: Input sequence lengths [batch_size]

        Returns:
            Tuple of (lengths after conv1, lengths after conv2)
        """

        def conv_out(lengths, kernel, stride, padding):
            # Standard 1-D convolution output-length formula.
            return (lengths + 2 * padding - kernel) // stride + 1

        # The constants mirror conv1 (k=3, s=1, p=1) and conv2 (k=3, s=2, p=1).
        after_conv1 = conv_out(input_lengths, 3, 1, 1)
        after_conv2 = conv_out(after_conv1, 3, 2, 1)
        return after_conv1, after_conv2

    def forward(self, input_features: torch.Tensor) -> _GlmAsrEncoderOutput:
        """
        Run the encoder.

        Args:
            input_features: [batch_size, num_mel_bins, seq_len]

        Returns:
            _GlmAsrEncoderOutput whose ``last_hidden_state`` is
            [batch_size, seq_len', hidden_size], where seq_len' is the
            sequence length after the convolutions.
        """
        gelu = nn.functional.gelu
        conv_features = gelu(self.conv1(input_features))
        conv_features = gelu(self.conv2(conv_features))

        # Channels-first -> [batch_size, seq_len', hidden_size].
        hidden_states = conv_features.transpose(1, 2)

        # Rotary angles for the post-conv length, cast to the activation dtype.
        angles = self.rotary_emb(hidden_states.shape[1])
        cos = angles.cos().to(dtype=hidden_states.dtype)
        sin = angles.sin().to(dtype=hidden_states.dtype)

        for layer in self.layers:
            hidden_states = layer(hidden_states, cos, sin)

        return _GlmAsrEncoderOutput(last_hidden_state=self.norm(hidden_states))

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load encoder weights, routing q_proj/k_proj/v_proj into ``qkv_proj``.

        Args:
            weights: Iterable of (checkpoint parameter name, tensor) pairs.

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """
        from vllm.model_executor.model_loader.weight_utils import default_weight_loader

        # (fused parameter name, checkpoint shard name, shard id)
        fused_shards = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        own_params = dict(self.named_parameters())
        loaded: set[str] = set()

        for name, tensor in weights:
            # NOTE: ``name`` is rewritten in place inside this loop and the
            # rewritten value is what the remaining iterations and the
            # ``else`` clause see; the statement order below is deliberate.
            for fused_name, shard_name, shard_id in fused_shards:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in own_params:
                    continue

                target = own_params[name]
                load_shard = target.weight_loader
                load_shard(target, tensor, shard_id)
                break
            else:
                # No fused shard was loaded: plain parameter, or skip it.
                if name.endswith(".bias") and name not in own_params:
                    continue
                if name not in own_params:
                    continue
                target = own_params[name]
                load_plain = getattr(target, "weight_loader", default_weight_loader)
                load_plain(target, tensor)
            loaded.add(name)
        return loaded
|
|
|
|
|
|
class GlmAsrFeatureInputs(TensorSchema):
    """
    Schema for raw audio feature inputs.

    Dimensions:
        - num_chunks: Number of audio chunks (flattened)
        - nmb: Number of mel bins
        - chunk_length: Per-chunk feature length (declared dynamic)
        - num_audios: Number of original audio files
    """

    # Discriminator used to tell this schema apart from the embedding schema.
    type: Literal["audio_features"]
    # Per-chunk features.
    input_features: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_chunks", "nmb", "chunk_length", dynamic_dims={"chunk_length"}),
    ]
    # Mask over the chunk_length axis of ``input_features``.
    feature_attention_mask: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_chunks", "chunk_length", dynamic_dims={"chunk_length"}),
    ]
    # Number of chunks belonging to each original audio.
    chunk_counts: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("num_audios"),
    ]
|
|
|
|
|
|
class GlmAsrEmbeddingInputs(TensorSchema):
    """
    Schema for precomputed audio embeddings.

    Dimensions:
        - bn: Batch size
        - naf: Number of audio features (declared dynamic)
        - hs: Hidden size (must match the hidden size of language model
          backbone)
    """

    # Discriminator; defaults to the only allowed value.
    type: Literal["audio_embeds"] = "audio_embeds"
    # One embedding tensor per audio item.
    audio_embeds: Annotated[
        list[torch.Tensor],
        TensorShape("bn", "naf", "hs", dynamic_dims={"naf"}),
    ]
|
|
|
|
|
|
# Either of the two audio input schemas: raw features or precomputed embeddings.
GlmAsrInputs: TypeAlias = GlmAsrFeatureInputs | GlmAsrEmbeddingInputs
|
|
|
|
|
|
class GlmAsrMultiModalProjector(nn.Module):
    """
    Two-layer MLP that maps audio encoder features into the language model's
    hidden space.

    Layout:
        - ``linear_1`` (column-parallel):
          audio_config.intermediate_size -> 2 * text_config.hidden_size
        - activation named by ``config.projector_hidden_act``
        - ``linear_2`` (row-parallel):
          2 * text_config.hidden_size -> text_config.hidden_size
    """

    def __init__(
        self,
        config: GlmAsrConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: Full GLM-ASR config (audio and text sub-configs are read).
            quant_config: Optional quantization config for both linear layers.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        audio_width = config.audio_config.intermediate_size
        text_width = config.text_config.hidden_size
        inner_width = text_width * 2

        self.linear_1 = ColumnParallelLinear(
            input_size=audio_width,
            output_size=inner_width,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_1",
        )
        self.act = get_act_fn(config.projector_hidden_act)
        self.linear_2 = RowParallelLinear(
            input_size=inner_width,
            output_size=text_width,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_2",
        )

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project ``audio_features`` through linear_1 -> act -> linear_2."""
        # Each parallel linear returns a pair; only the output tensor is used.
        widened, _ = self.linear_1(audio_features)
        projected, _ = self.linear_2(self.act(widened))
        return projected
|
|
|
|
|
|
def _glmasr_field_config(
    hf_inputs: Mapping[str, torch.Tensor],
) -> dict[str, MultiModalFieldConfig]:
    """
    Describe how each GLM-ASR multimodal field is split across audio items.

    ``audio_embeds`` and ``chunk_counts`` are always batched per audio. The
    two per-chunk fields (``input_features``, ``feature_attention_mask``) are
    sliced along dim 0 by ``chunk_counts`` when that entry is present in
    ``hf_inputs``; otherwise they are batched per audio as well.

    Args:
        hf_inputs: Dictionary of preprocessed inputs from HuggingFace processor.

    Returns:
        Dictionary mapping field names to MultiModalFieldConfig objects.
    """
    counts = hf_inputs.get("chunk_counts")

    def per_chunk_field() -> MultiModalFieldConfig:
        # Builds a fresh config for one per-chunk field.
        if counts is None:
            return MultiModalFieldConfig.batched("audio")
        return MultiModalFieldConfig.flat_from_sizes("audio", counts, dim=0)

    return {
        "audio_embeds": MultiModalFieldConfig.batched("audio"),
        "input_features": per_chunk_field(),
        "feature_attention_mask": per_chunk_field(),
        "chunk_counts": MultiModalFieldConfig.batched("audio"),
    }
|
|
|
|
|
|
class GlmAsrMultiModalDataParser(MultiModalDataParser):
    """
    Data parser for GLM-ASR audio inputs.

    A ``dict`` payload is treated as precomputed embeddings and must contain
    ``audio_embeds``; every other payload is handed to the base parser.
    """

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse one audio payload into modality items."""
        if not isinstance(data, dict):
            # Anything that is not a dict goes through the default path.
            return super()._parse_audio_data(data)

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"audio_embeds"},
            fields_factory=_glmasr_field_config,
        )
|
|
|
|
|
|
class GlmAsrProcessingInfo(BaseProcessingInfo):
    """
    Accessors for the GLM-ASR config, HF processor and feature extractor,
    plus the data parser and multimodal limits used during preprocessing.
    """

    def get_hf_config(self) -> GlmAsrConfig:
        """Return the HF config, typed as ``GlmAsrConfig``."""
        return self.ctx.get_hf_config(GlmAsrConfig)

    def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor:
        """Return the HF processor, forwarding ``kwargs`` to the context."""
        return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        """Return the feature extractor owned by the HF processor."""
        processor = self.get_hf_processor(**kwargs)
        return processor.feature_extractor

    def get_data_parser(self):
        """Build the GLM-ASR data parser at the extractor's sampling rate."""
        sampling_rate = self.get_feature_extractor().sampling_rate
        return GlmAsrMultiModalDataParser(
            target_sr=sampling_rate,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare audio as the only modality, with a limit of ``None``."""
        return {"audio": None}
|
|
|
|
|
|
class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]):
    """
    Produces placeholder prompts and audio clips for GLM-ASR.

    The dummy text is the processor's audio token repeated once per audio; the
    dummy audio is sized from the processor's ``max_audio_len`` (or
    ``DEFAULT_MAX_AUDIO_LEN_S`` when the processor has no such attribute).
    """

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the audio token repeated ``mm_counts['audio']`` times."""
        audio_token = self.info.get_hf_processor().audio_token
        return audio_token * mm_counts.get("audio", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """
        Build dummy audio data.

        Args:
            seq_len: Not used by this builder.
            mm_counts: Requested item count per modality.
            mm_options: Optional per-modality overrides.

        Returns:
            A dict with a single ``"audio"`` entry.
        """
        sampling_rate = self.info.get_feature_extractor().sampling_rate
        max_seconds = getattr(
            self.info.get_hf_processor(), "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S
        )
        # Clip length in samples.
        n_samples = int(max_seconds * sampling_rate)

        dummy_audios = self._get_dummy_audios(
            length=n_samples,
            num_audios=mm_counts.get("audio", 0),
            overrides=mm_options.get("audio"),
        )
        return {"audio": dummy_audios}
|
|
|
|
|
|
class GlmAsrMultiModalProcessor(BaseMultiModalProcessor["GlmAsrProcessingInfo"]):
    """
    GLM-ASR multimodal processor built directly on BaseMultiModalProcessor.

    It wraps the HF processor call (normalising the audio key, the mask key
    and adding ``chunk_counts``) and expands each audio placeholder token into
    as many tokens as the audio produces encoder features.
    """

    @staticmethod
    def _normalize_audio_list(audio: object) -> list[Any]:
        """Return ``audio`` as a list of clips; empty input gives ``[]``."""
        if audio is None:
            return []
        if isinstance(audio, list):
            return audio
        # Arrays and tensors must not be truth-tested: a multi-element array
        # raises on bool(). Check their element count instead.
        if isinstance(audio, np.ndarray):
            return [audio] if audio.size else []
        if isinstance(audio, torch.Tensor):
            return [audio] if audio.numel() else []
        return [audio] if audio else []

    def _calculate_chunk_counts(
        self,
        audio_list: list[Any],
        feature_extractor: WhisperFeatureExtractor,
        processor: GlmAsrProcessor,
    ) -> list[int]:
        """
        Count the fixed-size windows each audio clip is split into.

        Args:
            audio_list: Audio clips; each is a list of samples or an object
                whose ``shape[0]`` is the sample count.
            feature_extractor: Supplies ``sampling_rate`` and ``chunk_length``.
            processor: Supplies ``max_audio_len`` when it defines one
                (``DEFAULT_MAX_AUDIO_LEN_S`` otherwise).

        Returns:
            One chunk count per clip: at least 1 before capping, capped at
            ``max_audio_len // chunk_length``.
        """
        sampling_rate = feature_extractor.sampling_rate
        chunk_length = feature_extractor.chunk_length
        max_audio_len = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
        window_size = int(sampling_rate * chunk_length)
        max_windows = int(max_audio_len // chunk_length)

        chunk_counts = []
        for audio in audio_list:
            n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
            # Ceiling division, with a floor of one window per clip.
            n_chunks = max(1, (n_samples + window_size - 1) // window_size)
            chunk_counts.append(min(n_chunks, max_windows))
        return chunk_counts

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor and normalise its output for vLLM.

        Args:
            prompt: Text prompt.
            mm_data: Multimodal data; a deprecated ``"audios"`` key is renamed
                to ``"audio"`` in place.
            mm_kwargs: Extra processor kwargs; ``sampling_rate`` is added.
            tok_kwargs: Tokenizer kwargs forwarded to the parent call.

        Returns:
            The processor output with ``feature_attention_mask`` and
            ``chunk_counts`` set, or a tokens-only ``BatchFeature`` when no
            audio is present.
        """
        # Normalize input: handle deprecated key and list conversion.
        if "audios" in mm_data:
            mm_data["audio"] = mm_data.pop("audios")

        audio_list = self._normalize_audio_list(mm_data.get("audio"))

        # Early return for text-only.
        if not audio_list:
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Pass the extractor's own sampling rate to the HF processor.
        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
        )

        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        # Postprocess: expose the mask under a single key name regardless of
        # what the installed transformers version calls it.
        if "input_features_mask" in outputs:
            outputs["feature_attention_mask"] = outputs.pop("input_features_mask")
        elif "feature_attention_mask" not in outputs and "input_features" in outputs:
            # No mask under either name: synthesise an all-ones mask. A mask
            # already present as ``feature_attention_mask`` must be kept.
            input_features = outputs["input_features"]
            if isinstance(input_features, torch.Tensor):
                outputs["feature_attention_mask"] = torch.ones(
                    input_features.shape[0],
                    input_features.shape[-1],
                    dtype=torch.long,
                )

        # Chunk counts are computed here with GLM-ASR specific logic.
        processor = self.info.get_hf_processor(**mm_kwargs)
        chunk_counts = self._calculate_chunk_counts(
            audio_list, processor.feature_extractor, processor
        )
        outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Return the per-field batching config for ``hf_inputs``."""
        return _glmasr_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the prompt replacement that expands each audio token.

        Each occurrence of the audio token is replaced by that token repeated
        once per encoder feature of the corresponding audio item. Feature
        counts come from ``feature_attention_mask`` when available, otherwise
        from the first dimension of the item's ``audio_embeds``.

        Raises (from the replacement callback):
            ValueError: If neither a mask nor embeddings are available, or if
                an audio item yields zero features.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()
        config = self.info.get_hf_config()

        audio_token = getattr(processor, "audio_token", "<|pad|>")
        audio_token_id = vocab.get(audio_token)
        if audio_token_id is None:
            # Token string not in the vocabulary: use the processor's id.
            audio_token_id = processor.audio_token_id

        merge_factor = getattr(config, "merge_factor", DEFAULT_MERGE_FACTOR)
        conv_params = getattr(config, "conv_params", DEFAULT_CONV_PARAMS)
        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        chunk_counts = out_mm_data.get("chunk_counts")

        # Pre-compute audio output lengths if feature_attention_mask is available
        audio_output_lengths: list[int] = []
        if feature_attention_mask is not None:
            from .glmasr_utils import (
                _as_list_chunk_counts,
                _get_audio_output_lengths_from_mask,
            )

            if chunk_counts is not None:
                # Masks are laid out chunk by chunk; consume ``count`` rows
                # per audio and sum the per-chunk lengths.
                start_idx = 0
                for count in _as_list_chunk_counts(chunk_counts):
                    end_idx = start_idx + count
                    mask = feature_attention_mask[start_idx:end_idx]
                    if isinstance(mask, list):
                        mask = torch.stack(mask)

                    lengths = _get_audio_output_lengths_from_mask(
                        mask, merge_factor, conv_params
                    )
                    audio_output_lengths.append(int(lengths.sum().item()))
                    start_idx = end_idx
            else:
                # Single chunk per audio
                for idx in range(len(feature_attention_mask)):
                    mask = feature_attention_mask[idx : idx + 1]
                    if isinstance(mask, list):
                        # NOTE(review): the chunked branch uses torch.stack
                        # for list masks, this one torch.tensor + unsqueeze --
                        # confirm which form list-valued masks actually need.
                        mask = torch.tensor(mask).unsqueeze(0)
                    lengths = _get_audio_output_lengths_from_mask(
                        mask, merge_factor, conv_params
                    )
                    audio_output_lengths.append(int(lengths.sum().item()))

        def get_replacement_glmasr(item_idx: int):
            # Use pre-computed lengths if available, otherwise fall back to audio_embeds
            if audio_output_lengths:
                num_features = audio_output_lengths[item_idx]
            else:
                audio_embeds = out_mm_data.get("audio_embeds")
                if audio_embeds is not None:
                    embed = audio_embeds[item_idx]
                    num_features = embed.shape[0]
                else:
                    raise ValueError(
                        "Either feature_attention_mask or audio_embeds must be provided"
                    )

            if num_features == 0:
                raise ValueError("Audio is too short")

            audio_tokens = [audio_token_id] * int(num_features)
            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_glmasr,
            )
        ]
|
|
|
|
|
|
@MULTIMODAL_REGISTRY.register_processor(
    GlmAsrMultiModalProcessor,
    info=GlmAsrProcessingInfo,
    dummy_inputs=GlmAsrDummyInputsBuilder,
)
class GlmAsrForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
    """GLM-ASR model: audio tower + projector feeding a language model.

    Audio features are encoded by ``audio_tower``, merged frame-wise, mapped
    by ``multi_modal_projector`` and handed to the language model as
    multimodal embeddings. The class also builds the chat prompts used for
    transcription / translation requests.
    """

    # Language table used to turn a language code into the name placed in the
    # translation prompt (see get_generation_prompt).
    supported_languages = ISO639_1_SUPPORTED_LANGS

    # Fused module name -> names of the original projections it packs.
    # NOTE(review): presumably consumed by the LoRA / weight-loading machinery.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the audio tower, the projector and the language model.

        Args:
            vllm_config: Engine configuration; the HF config, quantization
                config and multimodal config are read from it.
            prefix: Parameter-name prefix prepended to every sub-module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        # Encoder and projector are both created under the "audio" tower mark.
        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = GlmAsrEncoder(
                config.audio_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "audio_tower"),
            )
            self.multi_modal_projector = GlmAsrMultiModalProjector(
                config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "multi_modal_projector"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["LlamaForCausalLM"],
            )

        # Re-export the language model's factory under this model's name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for the ``i``-th item.

        Raises:
            ValueError: If ``modality`` does not start with ``"audio"``.
        """
        if modality.startswith("audio"):
            return "<|begin_of_audio|><|pad|><|end_of_audio|>"

        raise ValueError("Only audio modality is supported")

    def get_mm_mapping(self) -> MultiModelKeys:
        """Return the parameter-name prefixes of the model's three parts."""
        return MultiModelKeys.from_string_field(
            language_model="language_model.",
            connector="multi_modal_projector.",
            tower_model="audio_tower.",
        )

    def _parse_and_validate_audio_input(self, **kwargs: object) -> GlmAsrInputs | None:
        """Turn raw multimodal kwargs into a typed audio input.

        Pre-computed ``audio_embeds`` take precedence over ``input_features``.

        Returns:
            An embedding input, a feature input, or ``None`` when neither
            ``audio_embeds`` nor ``input_features`` was supplied.
        """
        audio_embeds = kwargs.pop("audio_embeds", None)
        if audio_embeds is not None:
            return GlmAsrEmbeddingInputs(type="audio_embeds", audio_embeds=audio_embeds)

        input_features = kwargs.pop("input_features", None)
        if input_features is None:
            return None

        return GlmAsrFeatureInputs(
            type="audio_features",
            input_features=input_features,
            feature_attention_mask=kwargs.pop("feature_attention_mask", None),
            chunk_counts=kwargs.pop("chunk_counts", None),
        )

    def _process_audio_input(
        self, audio_input: GlmAsrInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the audio embeddings for a parsed audio input.

        Pre-computed embeddings are returned as-is (as a tuple). Otherwise the
        features go through the audio tower, consecutive frames are merged,
        the projector is applied, padding is dropped according to the
        attention mask, and the per-chunk results are grouped by
        ``chunk_counts``.

        Raises:
            ValueError: If features are given without
                ``feature_attention_mask``.
        """
        if audio_input["type"] == "audio_embeds":
            return tuple(audio_input["audio_embeds"])

        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]

        # The mask is required further down to compute the valid output
        # lengths. Fail early with a clear message instead of an opaque
        # TypeError/AttributeError after the encoder has already run.
        if feature_attention_mask is None:
            raise ValueError(
                "feature_attention_mask must be provided together with "
                "input_features"
            )

        if isinstance(input_features, list):
            # Per-item tensors: stack all chunks along the leading dimension.
            input_features = torch.cat(input_features, dim=0)
            feature_attention_mask = torch.cat(feature_attention_mask, dim=0)

        num_chunks = input_features.shape[0]
        chunk_counts = _normalize_chunk_counts(
            audio_input.get("chunk_counts"), num_chunks=num_chunks
        )

        # Convert input_features to model dtype (e.g., bfloat16) to match model weights
        input_features = input_features.to(dtype=self.audio_tower.conv1.weight.dtype)

        # audio_tower returns [batch_size, seq_len, hidden_size] where hidden_size=1280
        audio_hidden_states = self.audio_tower(input_features).last_hidden_state

        # GLM-ASR merges consecutive frames: 4 frames with hidden_size=1280
        # -> 1 frame with intermediate_size=5120
        hidden_size = self.config.audio_config.hidden_size
        intermediate_size = self.config.audio_config.intermediate_size
        merge_ratio = intermediate_size // hidden_size

        # Truncate sequence length to be divisible by merge_ratio
        seq_len = audio_hidden_states.shape[1]
        seq_len_truncated = (seq_len // merge_ratio) * merge_ratio
        if seq_len_truncated < seq_len:
            audio_hidden_states = audio_hidden_states[:, :seq_len_truncated, :]

        # Reshape to merge consecutive frames
        audio_hidden_states = audio_hidden_states.reshape(
            num_chunks,
            -1,
            intermediate_size,
        )

        audio_features = self.multi_modal_projector(audio_hidden_states)

        merge_factor = getattr(self.config, "merge_factor", DEFAULT_MERGE_FACTOR)
        conv_params = getattr(self.config, "conv_params", DEFAULT_CONV_PARAMS)

        # Number of valid output positions per chunk, derived from the mask.
        audio_output_lengths = _get_audio_output_lengths_for_tower(
            self.audio_tower,
            feature_attention_mask.sum(-1),
            merge_factor,
            conv_params,
        )

        masked_audio_features = _flatten_audio_features_by_length(
            audio_features, audio_output_lengths
        )

        # Split the flattened features back into one tensor per chunk, then
        # regroup the chunks per audio item.
        chunk_embeddings = torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
        )
        return _group_audio_embeddings(chunk_embeddings, chunk_counts)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the audio embeddings for the given multimodal kwargs.

        Returns an empty list when the kwargs carry no audio input.
        """
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []

        return self._process_audio_input(audio_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model backbone and return its output.

        ``inputs_embeds`` is discarded whenever ``intermediate_tensors`` is
        given.
        """
        # NOTE(review): presumably the non-first pipeline-parallel rank case,
        # where hidden states arrive via intermediate_tensors - confirm.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the loader's result.

        Weights under ``audio_tower.embed_positions`` are skipped.
        """
        # NOTE(review): presumably skipped because the encoder has no such
        # parameter (rotary embeddings are used instead) - confirm.
        skip_prefixes = ["audio_tower.embed_positions"]
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights)

    @classmethod
    def _get_audio_token(cls, model_config: ModelConfig) -> str:
        """Get the audio token from processor.

        Similar to get_placeholder_str but returns single token. Falls back
        to ``"<|pad|>"`` when the processor has no ``audio_token`` attribute.
        """
        processor = cached_processor_from_config(model_config)
        return getattr(processor, "audio_token", "<|pad|>")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text config from the cached processor.

        The sample rate comes from the processor's feature extractor; the
        maximum clip length from ``processor.max_audio_len`` when present,
        otherwise ``DEFAULT_MAX_AUDIO_LEN_S``. ``task_type`` is not used.
        """
        processor = cached_processor_from_config(model_config)
        feature_extractor = processor.feature_extractor
        max_audio_clip_s = getattr(processor, "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S)
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=feature_extractor.sampling_rate,
        )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests.

        A single user message (audio token followed by a fixed instruction)
        is rendered through the tokenizer's chat template and tokenized.

        Args:
            audio: Audio passed through unchanged as multimodal data.
            model_config: Used to look up the cached tokenizer and processor.
            stt_config: Not used by this implementation.
            language: Not used by this implementation.
            task_type: ``"transcribe"`` or ``"translate"``.
            request_prompt: Not used by this implementation.
            to_language: Target language for translation. Looked up in
                ``supported_languages``; unknown values are used verbatim.
                Defaults to English when missing.

        Raises:
            ValueError: If ``task_type`` is neither supported value.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_token = cls._get_audio_token(model_config)

        if task_type == "translate":
            if to_language:
                full_lang_name_to = cls.supported_languages.get(
                    to_language, to_language
                )
            else:
                # Without a target the prompt used to read "... to None".
                # English is the conventional default for audio translation.
                full_lang_name_to = "English"
            user_content = f"{audio_token}translate the speech to {full_lang_name_to}"
        elif task_type == "transcribe":
            user_content = (
                f"{audio_token}can you transcribe the speech into a written format?"
            )
        else:
            raise ValueError(f"Unsupported task type {task_type}")

        messages = [{"role": "user", "content": user_content}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        prompt_token_ids = tokenizer.encode(prompt)

        return TokensPrompt(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data={"audio": audio},
        )
|