adapt voxtral (#31095)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent b10d47e0e0
commit 3faa8bee57
@@ -1,9 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import enum
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
+from functools import partial
 from typing import Annotated, Literal, cast

 import numpy as np
@@ -16,7 +18,10 @@ from transformers import (
 )
 from transformers.models.whisper.modeling_whisper import sinusoids

-from vllm.attention.layer import Attention, AttentionType
+from vllm.attention.backends.abstract import (
+    AttentionType,
+)
+from vllm.attention.layer import Attention
 from vllm.attention.layers.cross_attention import CrossAttention
 from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
@@ -34,6 +39,11 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.whisper_utils import (
+    ISO639_1_SUPPORTED_LANGS,
+    WhisperAttentionWithBlockPooling,
+    WhisperCausalConv1d,
+)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
@@ -64,67 +74,11 @@ from .utils import (

 logger = init_logger(__name__)

-# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
-ISO639_1_SUPPORTED_LANGS = {
-    "af": "Afrikaans",
-    "ar": "Arabic",
-    "hy": "Armenian",
-    "az": "Azerbaijani",
-    "be": "Belarusian",
-    "bs": "Bosnian",
-    "bg": "Bulgarian",
-    "ca": "Catalan",
-    "zh": "Chinese",
-    "hr": "Croatian",
-    "cs": "Czech",
-    "da": "Danish",
-    "nl": "Dutch",
-    "en": "English",
-    "et": "Estonian",
-    "fi": "Finnish",
-    "fr": "French",
-    "gl": "Galician",
-    "de": "German",
-    "el": "Greek",
-    "he": "Hebrew",
-    "hi": "Hindi",
-    "hu": "Hungarian",
-    "is": "Icelandic",
-    "id": "Indonesian",
-    "it": "Italian",
-    "ja": "Japanese",
-    "kn": "Kannada",
-    "kk": "Kazakh",
-    "ko": "Korean",
-    "lv": "Latvian",
-    "lt": "Lithuanian",
-    "mk": "Macedonian",
-    "ms": "Malay",
-    "mr": "Marathi",
-    "mi": "Maori",
-    "ne": "Nepali",
-    "no": "Norwegian",
-    "fa": "Persian",
-    "pl": "Polish",
-    "pt": "Portuguese",
-    "ro": "Romanian",
-    "ru": "Russian",
-    "sr": "Serbian",
-    "sk": "Slovak",
-    "sl": "Slovenian",
-    "es": "Spanish",
-    "sw": "Swahili",
-    "sv": "Swedish",
-    "tl": "Tagalog",
-    "ta": "Tamil",
-    "th": "Thai",
-    "tr": "Turkish",
-    "uk": "Ukrainian",
-    "ur": "Urdu",
-    "vi": "Vietnamese",
-    "cy": "Welsh",
-}
+class WhisperPosEmbedType(enum.Enum):
+    SINUSOIDAL = "sinusoidal"
+    NOPE = "nope"
+    LEARNED = "learned"


 class WhisperAudioInputs(TensorSchema):
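Note: `WhisperPosEmbedType` now gates all position-embedding behavior in the encoder (the language table itself moved verbatim to `whisper_utils`). A minimal sketch of how `__init__` resolves the mode from the HF config; `DummyConfig` below is a hypothetical stand-in, not vLLM code:

```python
import enum

class WhisperPosEmbedType(enum.Enum):
    SINUSOIDAL = "sinusoidal"
    NOPE = "nope"
    LEARNED = "learned"

class DummyConfig:
    # Hypothetical stand-in for the HF config; a Voxtral-style causal
    # encoder would set pos_embed = "nope".
    pos_embed = "nope"

# Mirrors WhisperEncoder.__init__: unknown strings raise ValueError, and
# configs without the field fall back to stock Whisper sinusoidal behavior.
mode = WhisperPosEmbedType(getattr(DummyConfig, "pos_embed", "sinusoidal"))
assert mode is WhisperPosEmbedType.NOPE
```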
@@ -184,6 +138,8 @@ class WhisperAttention(nn.Module):
         num_heads: int,
         bias: bool = True,
         attn_type: AttentionType = AttentionType.DECODER,
+        per_layer_sliding_window: int | None = None,
+        block_pool_size: int = 1,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
@@ -242,7 +198,14 @@ class WhisperAttention(nn.Module):
                 attn_type=self.attn_type,
             )
         else:  # AttentionType.DECODER (regular decoder self-attention)
-            self.attn = Attention(
+            if block_pool_size > 1:
+                attn_cls = partial(
+                    WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
+                )
+            else:
+                attn_cls = Attention
+
+            self.attn = attn_cls(
                 self.num_heads,
                 self.head_dim,
                 self.scaling,
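The `attn_cls` switch keeps a single construction site for both attention variants. A self-contained sketch of the pattern with stand-in classes (only `partial` and the call shape are taken from this diff):

```python
from functools import partial

class Attention:  # stand-in for vllm.attention.layer.Attention
    def __init__(self, num_heads, head_dim, scale, **kwargs):
        self.num_heads, self.head_dim, self.scale = num_heads, head_dim, scale

class WhisperAttentionWithBlockPooling(Attention):  # stand-in
    def __init__(self, *args, block_pool_size: int = 1, **kwargs):
        super().__init__(*args, **kwargs)
        self.block_pool_size = block_pool_size

block_pool_size = 2
if block_pool_size > 1:
    # partial() pins the extra kwarg so both branches expose one signature
    attn_cls = partial(
        WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
    )
else:
    attn_cls = Attention

attn = attn_cls(8, 64, 64**-0.5)  # identical call either way
assert attn.block_pool_size == 2
```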
@@ -251,6 +214,7 @@ class WhisperAttention(nn.Module):
                 quant_config=quant_config,
                 prefix=f"{prefix}.attn",
                 attn_type=self.attn_type,
+                per_layer_sliding_window=per_layer_sliding_window,
             )

     def _init_qkv(
@@ -386,6 +350,9 @@ class WhisperEncoderLayer(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
+        is_causal = getattr(config, "is_causal", False)
+        sliding_window = getattr(config, "sliding_window", None)
+        block_pool_size = getattr(config, "block_pool_size", 1)
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config

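All three `getattr` reads default to the stock-Whisper values, so the new code paths are opt-in per config. A quick check against a hypothetical config lacking the new fields:

```python
class StockWhisperConfig:
    # hypothetical: a config without the new Voxtral-related fields
    encoder_attention_heads = 6

config = StockWhisperConfig()
assert getattr(config, "is_causal", False) is False     # bidirectional encoder
assert getattr(config, "sliding_window", None) is None  # full attention
assert getattr(config, "block_pool_size", 1) == 1       # plain Attention
```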
@@ -393,7 +360,9 @@ class WhisperEncoderLayer(nn.Module):
         self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
-            attn_type=AttentionType.ENCODER,
+            attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
+            block_pool_size=block_pool_size,
+            per_layer_sliding_window=sliding_window,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
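With `is_causal` set, encoder self-attention runs as DECODER-type (masked) attention, optionally windowed. The helper below is illustrative only — vLLM's attention backends build their own masks — but it shows the attention pattern this change enables:

```python
import torch

def causal_sliding_mask(seq_len: int, window: int | None) -> torch.Tensor:
    """True where attention is allowed."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    mask = j <= i                           # causal: no looking ahead
    if window is not None:
        mask &= (i - j) < window            # keep only the last `window` keys
    return mask

print(causal_sliding_mask(6, window=3).int())
```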
@@ -492,12 +461,21 @@ class WhisperEncoder(nn.Module):
         super().__init__()
         config = vllm_config.model_config.hf_config
         embed_dim = config.d_model
+
+        self.pos_embed_type = WhisperPosEmbedType(
+            getattr(config, "pos_embed", "sinusoidal")
+        )
         self.num_mel_bins = config.num_mel_bins
         self.max_source_positions = config.max_source_positions
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
-        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
+        is_causal = getattr(config, "is_causal", False)
+        Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
+
+        self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
+        self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
+
+        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.encoder_layers,
             lambda prefix: WhisperEncoderLayer(
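`WhisperCausalConv1d` lives in `whisper_utils` and isn't shown in this diff; presumably it replaces symmetric padding with left-only padding so no output frame depends on future audio. A minimal sketch under that assumption:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1dSketch(nn.Conv1d):
    """Assumed behavior: pad (kernel_size - 1) frames on the left only."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        left_pad = (self.kernel_size[0] - 1) * self.dilation[0]
        return super().forward(F.pad(x, (left_pad, 0)))

conv1 = CausalConv1dSketch(80, 512, kernel_size=3)            # stride 1
conv2 = CausalConv1dSketch(512, 512, kernel_size=3, stride=2)
out = conv2(conv1(torch.randn(1, 80, 100)))
# total_stride = conv1.stride[0] * conv2.stride[0] = 2, as computed above
assert out.shape[-1] == 50
```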
@@ -507,29 +485,54 @@ class WhisperEncoder(nn.Module):
         )
         self.layer_norm = nn.LayerNorm(config.d_model)

-        maybe_fp32_init_ctx = (
-            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
-        )
-
-        with (
-            torch.no_grad(),
-            maybe_fp32_init_ctx,
-        ):
-            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
-            self.embed_positions.weight.copy_(
-                sinusoids(*self.embed_positions.weight.shape)
-            )
-
-    def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
+        if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
+            raise ValueError(
+                "Only NOPE position embeddings are supported "
+                f"for causal models, but got {self.pos_embed_type}"
+            )
+        elif self.pos_embed_type in (
+            WhisperPosEmbedType.SINUSOIDAL,
+            WhisperPosEmbedType.LEARNED,
+        ):
+            maybe_fp32_init_ctx = (
+                set_default_torch_dtype(torch.float32)
+                if init_in_fp32
+                else nullcontext()
+            )
+
+            with (
+                torch.no_grad(),
+                maybe_fp32_init_ctx,
+            ):
+                self.embed_positions = nn.Embedding(
+                    self.max_source_positions, embed_dim
+                )
+                self.embed_positions.weight.copy_(
+                    sinusoids(*self.embed_positions.weight.shape)
+                )
+
+    def forward_conv(
+        self, input_features: torch.Tensor | list[torch.Tensor]
+    ) -> torch.Tensor:
         hidden_states = []
         input_is_batched = False
         for features in input_features:
             embeds = nn.functional.gelu(self.conv1(features))
             embeds = nn.functional.gelu(self.conv2(embeds))
-            embeds = embeds.transpose(-1, -2)
-            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
-                embeds.dtype
-            )
+
+            if self.pos_embed_type in (
+                WhisperPosEmbedType.SINUSOIDAL,
+                WhisperPosEmbedType.LEARNED,
+            ):
+                embeds = embeds.transpose(-1, -2)
+                embeds = (
+                    embeds + self.embed_positions.weight[: embeds.size(-2), :]
+                ).to(embeds.dtype)
+            elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
+                embeds = embeds.transpose(-1, -2).to(embeds.dtype)
+            else:
+                raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
+
             hidden_states.append(embeds)
             input_is_batched = embeds.ndim > 2
         # Input to MHA must be B x T x D
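For context on the SINUSOIDAL branch: the table copied into `embed_positions` comes from transformers' `sinusoids` helper, which is essentially the following (reproduced for reference; the NOPE branch skips the addition entirely):

```python
import math
import torch

def sinusoids(length: int, channels: int, max_timescale: int = 10000) -> torch.Tensor:
    """Whisper's fixed sin/cos position table (channels must be even)."""
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)

table = sinusoids(1500, 512)  # (max_source_positions, d_model)
assert table.shape == (1500, 512)
```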
@@ -539,12 +542,19 @@ class WhisperEncoder(nn.Module):
         else:
             hidden_states = torch.stack(hidden_states, dim=0)

+        return hidden_states
+
+    def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(hidden_states)

         hidden_states = self.layer_norm(hidden_states)
         return hidden_states

+    def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
+        hidden_states = self.forward_conv(input_features)
+        return self.forward_layers(hidden_states)
+

 class WhisperDecoder(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
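The `forward` split into `forward_conv` plus `forward_layers`, together with the new `total_stride` attribute, makes the encoder's downsampling explicit. A back-of-the-envelope check using standard Whisper numbers (assumed here for illustration):

```python
# conv1 has stride 1 and conv2 stride 2, so total_stride == 2: a 30 s clip
# at a 10 ms hop gives 3000 mel frames -> ~1500 encoder positions, matching
# max_source_positions in standard Whisper configs.
num_mel_frames = 3000
total_stride = 1 * 2
num_encoder_positions = num_mel_frames // total_stride
assert num_encoder_positions == 1500
```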