# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, cast

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    BatchFeature,
    Qwen3Config,
)

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.attention.mm_encoder_attention import (
    MMEncoderAttention,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.funasr import FunASRFeatureExtractor
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsTranscription,
    _require_is_multimodal,
)
from .qwen3 import Qwen3Model
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)

logger = init_logger(__name__)


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
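# Illustrative example (not executed): for lengths = torch.tensor([2, 4]) and
# maxlen=4, `sequence_mask` returns
#     [[1., 1., 0., 0.],
#      [1., 1., 1., 1.]]
# i.e. a (batch, maxlen) matrix that is 1 for valid frames and 0 for padding.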


class LayerNorm(torch.nn.LayerNorm):
    def __init__(self, nout, dim=-1):
        super().__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x: torch.Tensor):
        if self.dim == -1:
            return super().forward(x)
        return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size: int,
        size: int,
        self_attn: nn.Module,
        feed_forward: nn.Module,
        normalize_before=True,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor | None = None,
        cache=None,
        mask_shift_chunk=None,
        mask_att_chunk_encoder=None,
    ):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)

        if self.in_size == self.size:
            hidden_states = residual + self.self_attn(
                hidden_states,
                mask,
                mask_shift_chunk=mask_shift_chunk,
                mask_att_chunk_encoder=mask_att_chunk_encoder,
            )
        else:
            hidden_states = self.self_attn(
                hidden_states,
                mask,
                mask_shift_chunk=mask_shift_chunk,
                mask_att_chunk_encoder=mask_att_chunk_encoder,
            )

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = residual + self.feed_forward(hidden_states)

        return hidden_states, mask, cache, mask_shift_chunk, mask_att_chunk_encoder


class MultiHeadedAttentionSANM(nn.Module):
    def __init__(
        self,
        n_head: int,
        in_feat: int,
        n_feat: int,
        kernel_size: int,
        sanm_shift: int = 0,
    ):
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.out_proj = ReplicatedLinear(
            input_size=n_feat,
            output_size=n_feat,
            bias=True,
        )
        self.linear_q_k_v = ReplicatedLinear(
            input_size=in_feat,
            output_size=n_feat * 3,
            bias=True,
        )
        self.attn = None

        # Depthwise 1D convolution implementing the FSMN memory block.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # Asymmetric padding keeps the convolution output at the input length.
        left_padding = (kernel_size - 1) // 2
        if sanm_shift > 0:
            left_padding = left_padding + sanm_shift
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor | None,
        mask_shift_chunk: torch.Tensor | None = None,
    ):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shift_chunk is not None:
                mask = mask * mask_shift_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x: torch.Tensor):
        b, t, d = x.size()

        q_k_v, _ = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)

        return q_h, k_h, v_h, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor | None,
        mask_att_chunk_encoder: torch.Tensor | None = None,
    ):
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)

            min_value = -float("inf")
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(scores, dim=-1)

        p_attn = attn
        x = torch.matmul(p_attn, value)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)

        out, _ = self.out_proj(x)
        return out

    def forward(
        self,
        hidden_states: torch.Tensor,
        mask: torch.Tensor | None,
        mask_shift_chunk: torch.Tensor | None = None,
        mask_att_chunk_encoder: torch.Tensor | None = None,
    ):
        q_h, k_h, v_h, v = self.forward_qkv(hidden_states)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shift_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory
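# SANM attention thus computes `softmax(QK^T / sqrt(d_k)) @ V + FSMN(V)`: the
# usual scaled dot-product term plus a depthwise-convolution "memory" over the
# unprojected values, so local context is captured even where attention is flat.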


class SinusoidalPositionEncoder(torch.nn.Module):
    def __init__(self, d_model=80):
        super().__init__()

    def encode(
        self,
        positions: torch.Tensor | None = None,
        depth: int | None = None,
        dtype: torch.dtype = torch.float32,
    ):
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(
            torch.tensor([10000], dtype=dtype, device=device)
        ) / (depth / 2 - 1)
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype)
            * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, hidden_states: torch.Tensor):
        batch_size, timesteps, input_dim = hidden_states.size()
        positions = torch.arange(1, timesteps + 1, device=hidden_states.device)[None, :]
        position_encoding = self.encode(positions, input_dim, hidden_states.dtype).to(
            hidden_states.device
        )

        return hidden_states + position_encoding
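# The encoding above is the standard sinusoidal scheme: with half = depth / 2,
# channel i in [0, half) uses timescale 10000^(i / (half - 1)), giving
#     PE[p, i]        = sin(p / 10000^(i / (half - 1)))
#     PE[p, i + half] = cos(p / 10000^(i / (half - 1)))
# Note that positions start at 1 rather than 0 in `forward`.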


class SenseVoiceEncoderSmall(nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        attention_dropout_rate: float = 0.0,
        normalize_before: bool = True,
        kernel_size: int = 11,
        sanm_shift: int = 0,
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size
        self.embed = SinusoidalPositionEncoder()

        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        # The first layer projects from the raw feature size to `output_size`;
        # all remaining layers operate at `output_size`.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            kernel_size,
            sanm_shift,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            kernel_size,
            sanm_shift,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    input_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                )
                for _ in range(1)
            ]
        )
        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                )
                for _ in range(num_blocks - 1)
            ]
        )

        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                )
                for _ in range(tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(output_size)

        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        maxlen = xs_pad.shape[1]
        masks = sequence_mask(
            ilens, maxlen=maxlen, dtype=ilens.dtype, device=ilens.device
        )[:, None, :]

        xs_pad *= self.output_size() ** 0.5

        xs_pad = self.embed(xs_pad)

        for encoder_layer in self.encoders0:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for encoder_layer in self.encoders:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1).int()

        for encoder_layer in self.tp_encoders:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens


class PositionwiseFeedForward(nn.Module):
    def __init__(self, idim: int, hidden_units: int):
        super().__init__()
        self.w_1 = ColumnParallelLinear(
            input_size=idim,
            output_size=hidden_units,
            bias=True,
        )
        self.w_2 = RowParallelLinear(
            input_size=hidden_units,
            output_size=idim,
            bias=True,
        )
        self.activation = _ACTIVATION_REGISTRY["relu"]

    def forward(self, hidden_states: torch.Tensor):
        hidden_states, _ = self.w_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states, _ = self.w_2(hidden_states)
        return hidden_states


class EncoderLayer(nn.Module):
    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        feed_forward: nn.Module,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = residual + self.self_attn(hidden_states, None, None)
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = residual + self.feed_forward(hidden_states)

        return hidden_states


class FunASRAudioAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        tp_size = get_tensor_model_parallel_world_size()
        self.num_local_heads = self.num_heads // tp_size

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scaling = self.head_dim**-0.5

        self.qkv = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            total_num_kv_heads=self.num_heads,
            bias=True,
            prefix=f"{prefix}.qkv",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            bias=True,
            prefix=f"{prefix}.out_proj",
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_local_heads,
            head_size=self.head_dim,
            scale=self.scaling,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        bs, seq_length, _ = hidden_states.size()
        qkv, _ = self.qkv(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(bs, seq_length, -1, self.head_dim)
        k = k.view(bs, seq_length, -1, self.head_dim)
        v = v.view(bs, seq_length, -1, self.head_dim)

        attn_output = self.attn(
            query=q,
            key=k,
            value=v,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        attn_output = attn_output.view(bs, seq_length, -1)
        output, _ = self.out_proj(attn_output)
        return output
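# Tensor-parallel layout for the block above: QKVParallelLinear fuses the
# q/k/v projections and shards them column-wise across ranks (each rank keeps
# num_heads // tp_size heads), while the row-parallel out_proj reduces the
# partial outputs back to the full embed_dim.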


class Transformer(nn.Module):
    def __init__(
        self,
        downsample_rate=2,
        encoder_dim=1280,
        llm_dim=4096,
        ffn_dim: int = 2048,
        prefix: str = "",
        **kwargs,
    ):
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = ColumnParallelLinear(
            input_size=self.encoder_dim * self.k,
            output_size=ffn_dim,
            bias=True,
        )
        self.relu = nn.ReLU()
        self.linear2 = RowParallelLinear(
            input_size=ffn_dim,
            output_size=self.llm_dim,
            bias=True,
        )

        self.blocks = None
        if kwargs.get("n_layer", 2) > 0:
            self.blocks = nn.ModuleList(
                [
                    EncoderLayer(
                        llm_dim,
                        FunASRAudioAttention(
                            kwargs.get("attention_heads", 8),
                            llm_dim,
                            prefix=f"{prefix}.self_attn",
                        ),
                        PositionwiseFeedForward(
                            llm_dim,
                            llm_dim // 4,
                        ),
                    )
                    for _ in range(kwargs.get("n_layer", 2))
                ]
            )

    def forward(self, hidden_states: torch.Tensor, ilens: int = 0):
        batch_size, seq_len, dim = hidden_states.size()
        # Pad the time axis to a multiple of the downsample rate `k`, then fold
        # every `k` consecutive frames into one vector of size `dim * k`.
        chunk_num = (seq_len - 1) // self.k + 1
        pad_num = chunk_num * self.k - seq_len
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num, 0, 0), value=0.0)

        hidden_states = hidden_states.contiguous()
        hidden_states = hidden_states.view(batch_size, chunk_num, dim * self.k)
        hidden_states, _ = self.linear1(hidden_states)
        hidden_states = self.relu(hidden_states)
        hidden_states, _ = self.linear2(hidden_states)

        olens = (ilens - 1) // self.k + 1

        if self.blocks is not None:
            for block in self.blocks:
                hidden_states = block(hidden_states)
        return hidden_states, olens
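# Worked example of the folding above: with k=2 and a (B, 5, D) input, the
# sequence is padded to 6 frames and reshaped to (B, 3, 2*D) before the
# two-layer MLP maps it to (B, 3, llm_dim); olens for ilens=5 is (5-1)//2+1 = 3.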


class FunASRAudioInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - nmb: Number of mel bins
        - t: Time frames (M)
    """

    input_features: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b", "nmb", "t"),
    ]
    speech_lengths: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b"),
    ]
    fake_token_lengths: Annotated[
        list[torch.Tensor] | None,
        TensorShape("b"),
    ]


class FunASREncoder(nn.Module):
    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        super().__init__()
        self.audio_encoder = SenseVoiceEncoderSmall(
            input_size=560, **vllm_config.model_config.hf_config.audio_encoder_conf
        )
        self.audio_adaptor = Transformer(
            downsample_rate=1,
            use_low_frame_rate=True,
            ffn_dim=2048,
            llm_dim=1024,
            encoder_dim=512,
            n_layer=2,
            freeze=True,
            prefix=maybe_prefix(prefix, "audio_encoder"),
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with mapping from HuggingFace format."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("self_attn.qkv.", "self_attn.q_proj.", "q"),
            ("self_attn.qkv.", "self_attn.k_proj.", "k"),
            ("self_attn.qkv.", "self_attn.v_proj.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict.get(name)
                if param is not None:
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
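# Example of the stacked-parameter mapping above (illustrative names): a
# checkpoint entry "blocks.0.self_attn.q_proj.weight" is renamed to
# "blocks.0.self_attn.qkv.weight" and loaded into the fused QKV parameter as
# its "q" shard; k_proj/v_proj land in the same tensor as "k" and "v" shards.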


class FunASRModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.encoder = FunASREncoder(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "encoder")
        )
        self.decoder = Qwen3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder")
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )
        return decoder_outputs

    def get_encoder_outputs(
        self,
        speech: torch.Tensor | list[torch.Tensor] | None,
        speech_lengths: torch.Tensor | list[torch.Tensor] | None,
    ) -> torch.Tensor | None:
        # `feat_permute` is currently hard-coded to False, so the permute
        # branch below is never taken.
        self.feat_permute = False

        if self.feat_permute:
            encoder_out, encoder_out_lens = self.encoder.audio_encoder(
                speech.permute(0, 2, 1), speech_lengths
            )
        else:
            encoder_out, encoder_out_lens = self.encoder.audio_encoder(
                speech, speech_lengths
            )

        encoder_out, encoder_out_lens = self.encoder.audio_adaptor(
            encoder_out, encoder_out_lens
        )
        return encoder_out
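# Audio path in this model: mel features -> SenseVoiceEncoderSmall (SANM
# blocks) -> Transformer adaptor, which (per FunASREncoder's config above,
# encoder_dim=512, llm_dim=1024) projects encoder frames into the decoder's
# embedding space before they are merged with the token embeddings.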


class FunASRProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self) -> Qwen3Config:
        return self.ctx.get_hf_config(Qwen3Config)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 1}

    def get_feature_extractor(self, **kwargs: object) -> FunASRFeatureExtractor:
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, FunASRFeatureExtractor)
        return feature_extractor

    def get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.get_feature_extractor()
        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            target_channels=self.get_target_channels(),
        )

    def get_target_channels(self) -> int:
        return 1


class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)

        return "<|AUDIO|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        feature_extractor = self.info.get_feature_extractor()

        sampling_rate = feature_extractor.sampling_rate
        audio_len = feature_extractor.chunk_length * sampling_rate
        num_audios = mm_counts.get("audio", 0)

        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            ),
        }
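# Dummy audio sizing (illustrative numbers): with chunk_length = 30 s and
# sampling_rate = 16 kHz, each profiling clip would be 30 * 16000 = 480_000
# samples; the actual values come from the FunASRFeatureExtractor config.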


class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            mm_data = dict(audio=mm_data.pop("audios"))
            mm_kwargs = dict(
                **mm_kwargs,
                sampling_rate=feature_extractor.sampling_rate,
            )
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            input_features=MultiModalFieldConfig.batched("audio"),
            speech_lengths=MultiModalFieldConfig.batched("audio"),
            fake_token_lengths=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        audio_token_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.get_data()

        fake_token_lengths = out_mm_data.get("fake_token_lengths")
        if fake_token_lengths is None:
            audio_output_lengths = []
        else:
            assert isinstance(fake_token_lengths, torch.Tensor)

            audio_output_lengths = fake_token_lengths.tolist()

        def get_replacement_audio(item_idx: int):
            num_features = audio_output_lengths[item_idx]
            return [audio_token_id] * num_features

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_replacement_audio,
            )
        ]
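# Placeholder expansion (illustrative): if fake_token_lengths is [7] for one
# audio item, the single <|AUDIO|> token in the prompt is replaced by 7 copies
# of audio_token_id, so the placeholder length matches the encoder output for
# that clip.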


@MULTIMODAL_REGISTRY.register_processor(
    FunASRMultiModalProcessor,
    info=FunASRProcessingInfo,
    dummy_inputs=FunASRDummyInputsBuilder,
)
class FunASRForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            "linear_q.": "q_proj.",
            "linear_k.": "k_proj.",
            "linear_v.": "v_proj.",
            "linear_out.": "out_proj.",
            "audio_adaptor.": "model.encoder.audio_adaptor.",
            "audio_encoder.": "model.encoder.audio_encoder.",
            "llm.model.": "model.decoder.",
            "llm.lm_head": "lm_head",
        }
    )

    supports_transcription_only = True
    supports_segment_timestamp = True
    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        if language is None:
            # TODO: language should be optional and can be guessed.
            # For now we default to en. See
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
            logger.warning(
                "Defaulting to language='en'. If you wish to transcribe "
                "audio in a different language, pass the `language` field "
                "in the TranscriptionRequest."
            )
            language = "en"
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        if language is None:
            raise ValueError(
                "Language must be specified when creating the FunASR prompt"
            )

        # "语音转写" means "speech transcription".
        funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n"  # noqa: E501
        prompt = {
            "prompt": funasr_prompt,
            "multi_modal_data": {
                "audio": (audio, stt_config.sample_rate),
            },
        }
        return cast(PromptType, prompt)

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        processor = cached_processor_from_config(model_config)

        return SpeechToTextConfig(
            max_audio_clip_s=processor.feature_extractor.chunk_length,
            sample_rate=processor.feature_extractor.sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        processor = cached_processor_from_config(model_config)
        hop_length = processor.feature_extractor.hop_length
        assert hop_length is not None
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)
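    # Worked example (illustrative values): for a 30 s clip at a 16 kHz sample
    # rate with hop_length = 160, this yields ceil(30 * 16000 / 160) = 3000
    # feature frames, which is the token budget reserved for the audio.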

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        self.model = FunASRModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        logit_scale = getattr(config, "logit_scale", 1.0)

        if config.tie_word_embeddings:
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
        )
        return decoder_outputs

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        audio_input = self._parse_and_validate_audio_input(**kwargs)

        speech = audio_input["input_features"]
        speech_lengths = audio_input["speech_lengths"]
        enc_output = self.model.get_encoder_outputs(
            speech=speech, speech_lengths=speech_lengths
        )

        return enc_output

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        inputs_embeds = self.model.decoder.embed_input_ids(input_ids)

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=_require_is_multimodal(is_multimodal),
        )

    def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs:
        input_features = kwargs.pop("input_features", None)
        speech_lengths = kwargs.pop("speech_lengths", None)
        fake_token_lengths = kwargs.pop("fake_token_lengths", None)

        return FunASRAudioInputs(
            input_features=input_features,
            speech_lengths=speech_lengths,
            fake_token_lengths=fake_token_lengths,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
        )

        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)