Voxtral (#20970)

commit e7e3e6d263, parent 4ffd963fa0, committed via GitHub

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -3,6 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
+from contextlib import nullcontext
 from typing import Optional, TypedDict, Union, cast
 
 import numpy as np
@@ -13,6 +14,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.layer import MultiHeadAttention
 from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
                          VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -26,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
@@ -178,6 +181,7 @@ class WhisperAttention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        standalone_encoder: bool = False,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -213,16 +217,24 @@ class WhisperAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            attn_type=self.attn_type,
-        )
+        if standalone_encoder:
+            self.attn = MultiHeadAttention(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                num_kv_heads=self.num_kv_heads,
+            )
+        else:
+            self.attn = Attention(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                num_kv_heads=self.num_kv_heads,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+                attn_type=self.attn_type,
+            )
 
     def _init_qkv(
         self,
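
When the encoder is used standalone (as an audio tower rather than as half of an encoder-decoder model), every call attends over the full audio sequence at once, so no KV cache or attention metadata is involved and the cache-free MultiHeadAttention layer suffices. A minimal sketch of that cache-free path in plain PyTorch (not the vLLM implementation; assumes torch >= 2.1 for the scale keyword):

    import torch
    import torch.nn.functional as F

    def cache_free_attention(q: torch.Tensor, k: torch.Tensor,
                             v: torch.Tensor, scale: float) -> torch.Tensor:
        # Full-sequence attention with no KV cache; q, k, v are
        # (batch, heads, seq, head_dim), as in a standalone encoder pass.
        return F.scaled_dot_product_attention(q, k, v, scale=scale)

    # Smoke test with encoder-like shapes (sizes are illustrative):
    q = k = v = torch.randn(1, 8, 100, 64)
    out = cache_free_attention(q, k, v, scale=64 ** -0.5)
    assert out.shape == (1, 8, 100, 64)
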
@@ -357,7 +369,11 @@ class WhisperMLP(nn.Module):
 
 class WhisperEncoderLayer(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 is_standalone_encoder: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -371,6 +387,7 @@ class WhisperEncoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            standalone_encoder=is_standalone_encoder,
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.mlp = WhisperMLP(
@@ -462,10 +479,16 @@ class WhisperDecoderLayer(nn.Module):
 
 class WhisperEncoder(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 is_standalone_encoder: bool = False,
+                 init_in_fp32: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         embed_dim = config.d_model
+        self.is_standalone_encoder = is_standalone_encoder
         self.num_mel_bins = config.num_mel_bins
         self.max_source_positions = config.max_source_positions
         self.embed_scale = (math.sqrt(embed_dim)
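
The two new constructor flags are meant to be set together by a model that embeds the Whisper encoder as its audio tower. A hypothetical construction site (only the parameter names come from this diff; the prefix value and the surrounding vLLM plumbing are illustrative):

    # A Voxtral-style wrapper would own a standalone encoder with
    # fp32-initialized sinusoidal position tables:
    encoder = WhisperEncoder(vllm_config=vllm_config,
                             prefix="audio_tower",  # illustrative name
                             is_standalone_encoder=True,
                             init_in_fp32=True)
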
@@ -480,17 +503,25 @@ class WhisperEncoder(nn.Module):
                                kernel_size=3,
                                stride=2,
                                padding=1)
-        self.embed_positions = nn.Embedding(self.max_source_positions,
-                                            embed_dim)
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.encoder_layers,
             lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
-                                               prefix=f"{prefix}.layers"),
+                                               prefix=f"{prefix}.layers",
+                                               is_standalone_encoder=
+                                               is_standalone_encoder),
             prefix=f"{prefix}.layers",
         )
         self.layer_norm = nn.LayerNorm(config.d_model)
 
-        with torch.no_grad():
+        maybe_fp32_init_ctx = set_default_torch_dtype(
+            torch.float32) if init_in_fp32 else nullcontext()
+
+        with (
+                torch.no_grad(),
+                maybe_fp32_init_ctx,
+        ):
+            self.embed_positions = nn.Embedding(self.max_source_positions,
+                                                embed_dim)
             self.embed_positions.weight.copy_(
                 sinusoids(*self.embed_positions.weight.shape))
 
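
Sinusoidal position tables lose precision if materialized directly in fp16/bf16, which is why the new init_in_fp32 path wraps the embedding construction in a float32 default-dtype context, falling back to nullcontext otherwise. A minimal sketch of the pattern, with a stand-in default_dtype helper in place of vLLM's set_default_torch_dtype:

    from contextlib import contextmanager, nullcontext

    import torch

    @contextmanager
    def default_dtype(dtype: torch.dtype):
        # Temporarily switch the global default floating-point dtype.
        old = torch.get_default_dtype()
        torch.set_default_dtype(dtype)
        try:
            yield
        finally:
            torch.set_default_dtype(old)

    def build_positions(rows: int, cols: int,
                        init_in_fp32: bool) -> torch.Tensor:
        ctx = default_dtype(torch.float32) if init_in_fp32 else nullcontext()
        with torch.no_grad(), ctx:
            # Allocated in fp32 when requested, else in the ambient dtype.
            return torch.empty(rows, cols)
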
@@ -499,8 +530,10 @@
         for features in input_features:
             embeds = nn.functional.gelu(self.conv1(features))
             embeds = nn.functional.gelu(self.conv2(embeds))
-            embeds = embeds.permute(1, 0)
-            embeds = embeds + self.embed_positions.weight[:embeds.size(0), :]
+            embeds = embeds.transpose(-1, -2)
+            embeds = (embeds +
+                      self.embed_positions.weight[:embeds.size(-2), :]).to(
+                          embeds.dtype)
             hidden_states.append(embeds)
         hidden_states = torch.cat(hidden_states)
 
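
The reshuffle replaces a 2-D-only permute with transpose(-1, -2), which behaves identically for a single (d_model, T) conv output but also tolerates a leading batch dimension, and it indexes the position table by size(-2) to match. A toy shape check (sizes are illustrative, not Whisper's real dimensions):

    import torch

    d_model, T = 4, 6
    embeds = torch.randn(d_model, T)      # conv2 output for one clip
    pos = torch.randn(1500, d_model)      # stand-in position table
    embeds = embeds.transpose(-1, -2)     # -> (T, d_model)
    out = (embeds + pos[:embeds.size(-2), :]).to(embeds.dtype)
    assert out.shape == (T, d_model)
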
@@ -792,10 +825,14 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                 f"or {list(ISO639_1_OTHER_LANGS.values())}")
 
     @classmethod
-    def get_generation_prompt(cls, audio: np.ndarray,
-                              stt_config: SpeechToTextConfig, language: str,
-                              task_type: str,
-                              request_prompt: str) -> PromptType:
+    def get_generation_prompt(
+            cls,
+            audio: np.ndarray,
+            model_config: ModelConfig,  # not needed here
+            stt_config: SpeechToTextConfig,
+            language: str,
+            task_type: str,
+            request_prompt: str) -> PromptType:
         prompt = {
             "encoder_prompt": {
                 # Whisper does not support encoder prompt.