diff --git a/tests/models/multimodal/generation/test_voxtral_streaming.py b/tests/models/multimodal/generation/test_voxtral_streaming.py new file mode 100644 index 000000000..5cdf6f171 --- /dev/null +++ b/tests/models/multimodal/generation/test_voxtral_streaming.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import asdict + +import pytest +from mistral_common.audio import Audio +from mistral_common.protocol.instruct.chunk import RawAudio +from mistral_common.protocol.transcription.request import ( + StreamingMode, + TranscriptionRequest, +) +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from vllm import LLM, EngineArgs, SamplingParams +from vllm.assets.audio import AudioAsset + + +def _get_engine(path: str) -> LLM: + engine_args = EngineArgs( + model=path, + max_model_len=8192, + max_num_seqs=1, + limit_mm_per_prompt={"audio": 1}, + config_format="mistral", + load_format="mistral", + tokenizer_mode="mistral", + enforce_eager=True, + gpu_memory_utilization=0.4, + ) + return LLM(**asdict(engine_args)) + + +@pytest.mark.skip(reason="Voxtral streaming is not yet public") +def test_voxtral_streaming_forward(): + audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] + + model_name = "mistralai/Voxtral-Mini-3B-Realtime-2602" + tokenizer = MistralTokenizer.from_hf_hub(model_name) + audio_config = tokenizer.instruct_tokenizer.tokenizer.audio + + def from_file(file_path: str): + audio = Audio.from_file(file_path, strict=False) + req = TranscriptionRequest( + audio=RawAudio.from_audio(audio), + streaming=StreamingMode.OFFLINE, + language=None, + ) + tokenized = tokenizer.instruct_tokenizer.encode_transcription(req) + + return (tokenized.tokens, tokenized.audios[0].audio_array) + + tokenized_list = [ + from_file(audio_asset.get_local_path()) for audio_asset in audio_assets + ] + + inputs = [] + sampling_params = [] + + for tokens, audio_array in tokenized_list: + num_samples = audio_array.shape[0] + max_tokens = ( + audio_config.num_audio_tokens(num_samples) + - audio_config.num_delay_tokens + - 1 + ) + sampling_params.append(SamplingParams(temperature=0.0, max_tokens=max_tokens)) + + input_dict = { + "multi_modal_data": {"audio": [(audio_array, None)]}, + "prompt_token_ids": tokens, + } + inputs.append(input_dict) + + llm = _get_engine(model_name) + outputs = llm.generate( + inputs, + sampling_params=sampling_params, + ) + + texts = [out.outputs[0].text for out in outputs] + expected = [ + ( + " First words I spoke in the original phonograph. " + "A little piece of practical poetry. Mary had a little lamb," + " it sleeps with quite a snow, and everywhere that Mary went, " + "the lamb was sure to go." + ), + ( + " And the 0-1 pitch on the way to Edgar Martinez. Swung on" + " the line. Down the left field line for OBS. Here comes Joy. " + "Here is Junior to third base. They're going to wave him in. " + "The throw to the plate will be late. The Mariners are going" + " to play. For the American League Championship, " + "I don't believe it. It just continues. My oh, my." 
+ ), + ] + assert texts == expected diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 979cd9fdd..31858a365 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -404,6 +404,7 @@ class LlamaModel(nn.Module): positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, + **extra_layer_kwargs, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -422,7 +423,9 @@ class LlamaModel(nn.Module): ): if idx in self.aux_hidden_state_layers: aux_hidden_states.append(hidden_states + residual) - hidden_states, residual = layer(positions, hidden_states, residual) + hidden_states, residual = layer( + positions, hidden_states, residual, **extra_layer_kwargs + ) if not get_pp_group().is_last_rank: return IntermediateTensors( diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index a1d88d088..203169efc 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -10,6 +10,12 @@ from transformers import LlamaConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llama import ( LlamaAttention, @@ -17,11 +23,57 @@ from vllm.model_executor.models.llama import ( LlamaForCausalLM, LlamaModel, ) +from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionType from .utils import AutoWeightsLoader +class MistralMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + gate_up_proj_bias: bool | None = None, + prefix: str = "", + reduce_results: bool = True, + disable_tp: bool = False, + ) -> None: + super().__init__() + gate_up_proj_bias = bias if gate_up_proj_bias is None else gate_up_proj_bias + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=gate_up_proj_bias, + quant_config=quant_config, + disable_tp=disable_tp, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=disable_tp, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. Only silu is supported for now." 
+ ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + class MistralAttention(LlamaAttention): def __init__( self, @@ -114,6 +166,50 @@ class MistralDecoderLayer(LlamaDecoderLayer): self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj + if getattr(config, "ada_rms_norm_t_cond", False): + self.ada_rms_norm_t_cond = nn.Sequential( + ColumnParallelLinear( + input_size=config.hidden_size, + output_size=config.ada_rms_norm_t_cond_dim, + bias=False, + return_bias=False, + ), + nn.GELU(), + RowParallelLinear( + input_size=config.ada_rms_norm_t_cond_dim, + output_size=config.hidden_size, + bias=False, + return_bias=False, + ), + ) + else: + self.ada_rms_norm_t_cond = None + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + t_cond: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + if self.ada_rms_norm_t_cond is not None: + assert t_cond is not None + hidden_states = hidden_states * (1 + self.ada_rms_norm_t_cond(t_cond)) + + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + @support_torch_compile class MistralModel(LlamaModel): @@ -126,6 +222,18 @@ class MistralModel(LlamaModel): ): super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type) + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + t_cond: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: + return super().forward( + input_ids, positions, intermediate_tensors, inputs_embeds, t_cond=t_cond + ) + class MistralForCausalLM(LlamaForCausalLM): # Mistral: We don't support LoRA on the embedding layers. 
diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index 6153f12e4..c3d68967a 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -4,7 +4,7 @@ import inspect import math from collections.abc import Iterable, Mapping, Sequence -from functools import cached_property +from functools import cached_property, partial from math import ceil from typing import Literal, cast @@ -33,7 +33,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import SupportsPP from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.whisper import WhisperEncoder +from vllm.model_executor.models.whisper import ( + WhisperEncoder, + _create_fake_bias_for_k_proj, +) +from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( MultiModalDataDict, @@ -543,6 +547,7 @@ class VoxtralForConditionalGeneration( } ).named_parameters() ) + weights = _create_fake_bias_for_k_proj(weights, ".wk.weight") loaded_weights = set() @@ -730,6 +735,10 @@ class VoxtralEncoderModel(nn.Module): r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501 r"whisper_encoder.layers.\1.mlp.fc2.\2", ), + ( + r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(weight|bias)", + r"whisper_encoder.layers.\1.mlp.fc3.\2", + ), # noqa: E501 ( r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2", @@ -749,10 +758,15 @@ class VoxtralEncoderModel(nn.Module): super().__init__() self.config = cast(WhisperConfig, vllm_config.model_config.hf_config) self.dtype: torch.dtype = vllm_config.model_config.dtype - self.whisper_encoder = WhisperEncoder( + self.is_causal = getattr(self.config, "is_causal", False) + if self.is_causal: + WhisperEncoderCls = WhisperCausalEncoder + else: + WhisperEncoderCls = partial(WhisperEncoder, init_in_fp32=True) + + self.whisper_encoder = WhisperEncoderCls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "whisper_encoder"), - init_in_fp32=True, ) mel_filters = mel_filter_bank( num_frequency_bins=1 + self.config.window_size // 2, @@ -843,6 +857,22 @@ class VoxtralEncoderModel(nn.Module): ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] + params_mapping = [] + + if self.is_causal: + # For `WhisperCausalEncoder` we need + # some more renaming + stacked_params_mapping.extend( + [ + (".mlp.gate_up_proj", ".mlp.fc1", 0), + (".mlp.gate_up_proj", ".mlp.fc3", 1), + ] + ) + params_mapping.extend( + [ + (".mlp.down_proj", ".mlp.fc2"), + ] + ) params_dict = dict(self.named_parameters()) name, loaded_weight = weight @@ -860,6 +890,11 @@ class VoxtralEncoderModel(nn.Module): weight_loader(param, loaded_weight, shard_id) break else: + for param_name, weight_name in params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/voxtral_streaming.py b/vllm/model_executor/models/voxtral_streaming.py index 730576a40..fb20d986a 100644 --- a/vllm/model_executor/models/voxtral_streaming.py +++ b/vllm/model_executor/models/voxtral_streaming.py @@ -112,6 
+112,18 @@ class TimeEmbedding(torch.nn.Module): return torch.cat((emb.cos(), emb.sin()), dim=-1) # (B, D) or (B, T, D) +def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor: + # 1. Multiply by the scaling factor (e.g. 4) + base = input_tensor * scaling + + # 2. Create the offsets, e.g. [0, 1, 2, 3] + offsets = torch.arange(scaling, device=input_tensor.device) + + # 3. Use broadcasting, e.g. (N, 1) + (4,) results in (N, 4) + # Then flatten back to 1D + return (base.unsqueeze(1) + offsets).view(-1) + + @MULTIMODAL_REGISTRY.register_processor( VoxtralStreamingMultiModalProcessor, info=VoxtralProcessingInfo, @@ -175,8 +187,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size ) - audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers( - inputs_embeds + whisper_positions = _expand_tensor(positions, pool_size) + audio_hidden_states = self.whisper_encoder.whisper_encoder( + inputs_embeds, whisper_positions ) num_tokens, audio_hidden_size = audio_hidden_states.shape @@ -197,10 +210,14 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): device=inputs_embeds.device, dtype=inputs_embeds.dtype, ) - inputs_embeds = inputs_embeds + self.time_embedding(time_tensor) + t_cond = self.time_embedding(time_tensor) hidden_states = self.language_model.model( - input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + t_cond=t_cond, ) return hidden_states diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 7e3d470a5..ec3e5818e 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -5,7 +5,6 @@ import enum import math from collections.abc import Iterable, Mapping, Sequence from contextlib import nullcontext -from functools import partial from typing import Annotated, Literal, cast import numpy as np @@ -39,8 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.whisper_utils import ( ISO639_1_SUPPORTED_LANGS, - WhisperAttentionWithBlockPooling, - WhisperCausalConv1d, ) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import ( @@ -78,7 +75,7 @@ logger = init_logger(__name__) class WhisperPosEmbedType(enum.Enum): SINUSOIDAL = "sinusoidal" - NOPE = "nope" + ROPE = "rope" LEARNED = "learned" @@ -140,7 +137,6 @@ class WhisperAttention(nn.Module): bias: bool = True, attn_type: AttentionType = AttentionType.DECODER, per_layer_sliding_window: int | None = None, - block_pool_size: int = 1, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", @@ -199,14 +195,7 @@ class WhisperAttention(nn.Module): attn_type=self.attn_type, ) else: # AttentionType.DECODER (regular decoder self-attention) - if block_pool_size > 1: - attn_cls = partial( - WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size - ) - else: - attn_cls = Attention - - self.attn = attn_cls( + self.attn = Attention( self.num_heads, self.head_dim, self.scaling, @@ -351,9 +340,7 @@ class WhisperEncoderLayer(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - is_causal = getattr(config, "is_causal", False) 
sliding_window = getattr(config, "sliding_window", None) - block_pool_size = getattr(config, "block_pool_size", 1) cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -361,8 +348,7 @@ class WhisperEncoderLayer(nn.Module): self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, - attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER, - block_pool_size=block_pool_size, + attn_type=AttentionType.ENCODER, per_layer_sliding_window=sliding_window, cache_config=cache_config, quant_config=quant_config, @@ -470,13 +456,8 @@ class WhisperEncoder(nn.Module): self.max_source_positions = config.max_source_positions self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 - self.is_causal = getattr(config, "is_causal", False) - Conv1d = ( - WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1) - ) - - self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3) - self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3) + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1) self.total_stride = self.conv1.stride[0] * self.conv2.stride[0] self.start_layer, self.end_layer, self.layers = make_layers( @@ -488,33 +469,29 @@ class WhisperEncoder(nn.Module): ) self.layer_norm = nn.LayerNorm(config.d_model) - if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE: - raise ValueError( - "Only NOPE position embeddings are supported " - f"for causal models, but got {self.pos_embed_type}" - ) - elif self.pos_embed_type in ( + if self.pos_embed_type not in ( WhisperPosEmbedType.SINUSOIDAL, WhisperPosEmbedType.LEARNED, ): - maybe_fp32_init_ctx = ( - set_default_torch_dtype(torch.float32) - if init_in_fp32 - else nullcontext() + raise ValueError( + "Only sinusoidal or learned position embeddings are supported " + f"for non-causal models, but got {self.pos_embed_type}" ) - with ( - torch.no_grad(), - maybe_fp32_init_ctx, - ): - self.embed_positions = nn.Embedding( - self.max_source_positions, embed_dim - ) - self.embed_positions.weight.copy_( - sinusoids(*self.embed_positions.weight.shape) - ) + maybe_fp32_init_ctx = ( + set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext() + ) - def forward_conv( + with ( + torch.no_grad(), + maybe_fp32_init_ctx, + ): + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.embed_positions.weight.copy_( + sinusoids(*self.embed_positions.weight.shape) + ) + + def forward( self, input_features: torch.Tensor | list[torch.Tensor] ) -> torch.Tensor: hidden_states = [] @@ -523,44 +500,26 @@ class WhisperEncoder(nn.Module): embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv2(embeds)) - if self.pos_embed_type in ( - WhisperPosEmbedType.SINUSOIDAL, - WhisperPosEmbedType.LEARNED, - ): - embeds = embeds.transpose(-1, -2) - embeds = ( - embeds + self.embed_positions.weight[: embeds.size(-2), :] - ).to(embeds.dtype) - elif self.pos_embed_type == WhisperPosEmbedType.NOPE: - embeds = embeds.transpose(-1, -2).to(embeds.dtype) - else: - raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}") + embeds = embeds.transpose(-1, -2) + embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to( + embeds.dtype + ) hidden_states.append(embeds) input_is_batched = embeds.ndim > 2 # Input to MHA must be B x T x D - if input_is_batched or 
self.is_causal: + if input_is_batched: # Models using WhisperEncoder may handle batching internally. - # If WhisperEncoder is causal, sequences - # are not padded to have identical seq length (T) - # => concat over feature dim hidden_states = torch.cat(hidden_states) else: hidden_states = torch.stack(hidden_states, dim=0) - return hidden_states - - def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor: for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) hidden_states = self.layer_norm(hidden_states) return hidden_states - def forward(self, input_features: torch.Tensor | list[torch.Tensor]): - hidden_states = self.forward_conv(input_features) - return self.forward_layers(hidden_states) - @support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1}) class WhisperDecoder(nn.Module): @@ -978,19 +937,19 @@ class WhisperForConditionalGeneration( loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."]) # add fake zeros bias for k_proj to state_dict - weights = _create_fake_bias_for_k_proj(weights) + weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight") return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def _create_fake_bias_for_k_proj( - weights: Iterable[tuple[str, torch.Tensor]], + weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str ) -> Iterable[tuple[str, torch.Tensor]]: """ Create full zeros bias for k_proj weight in self-attn and x-attn layers. So that the bias for k_proj in qkv_proj can be initialized with zeros. """ for name, weight in weights: - if name.endswith(".k_proj.weight"): + if name.endswith(fake_bias_key_name): bias = torch.zeros(weight.size(0)) bias_name = name.replace("weight", "bias") yield from [(name, weight), (bias_name, bias)] diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py new file mode 100644 index 000000000..c547d5d3f --- /dev/null +++ b/vllm/model_executor/models/whisper_causal.py @@ -0,0 +1,465 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +import functools +import math +from dataclasses import replace +from functools import partial + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.mistral import MistralMLP +from vllm.model_executor.models.whisper import WhisperPosEmbedType +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionMetadata, + AttentionType, + CommonAttentionMetadata, + subclass_attention_backend_with_overrides, +) +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend +from vllm.v1.attention.selector import get_attn_backend +from vllm.v1.kv_cache_interface import AttentionSpec + +from .utils import make_layers + +CausalRMSNorm = partial(RMSNorm, eps=1e-5) + + +def _pad1d( + x: torch.Tensor, + paddings: tuple[int, int], + mode: str = "constant", + value: float = 0.0, +) -> torch.Tensor: + """Tiny wrapper around F.pad, just to allow for + reflect padding on small input. 
+ If this is the case, we insert extra 0 padding + to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == "reflect": + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +class WhisperCausalConv1d(nn.Conv1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + self._stride = self.stride[0] + self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1 + self._padding_total = self._effective_kernel_size - self._stride + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_frames = ( + x.shape[-1] - self._effective_kernel_size + self._padding_total + ) / self._stride + 1 + target_length = (math.ceil(n_frames) - 1) * self._stride + ( + self._effective_kernel_size - self._padding_total + ) + extra_padding = target_length - x.shape[-1] + x = _pad1d(x, (self._padding_total, extra_padding), mode="constant") + return super().forward(x) + + +@functools.lru_cache +def create_whisper_attention_backend_with_block_pooling( + underlying_attn_backend: AttentionBackend, block_pool_size: int +) -> type[AttentionBackend]: + prefix = "WhisperCausalAttentionWithBlockPooling_" + underlying_builder = underlying_attn_backend.get_builder_cls() + + class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + assert kv_cache_spec.num_kv_heads % block_pool_size == 0 + kv_cache_spec = replace( + kv_cache_spec, + block_size=kv_cache_spec.block_size * block_pool_size, + num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size, + ) + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + new_common_attn_metadata = copy.deepcopy(common_attn_metadata) + new_common_attn_metadata.query_start_loc *= block_pool_size + new_common_attn_metadata.query_start_loc_cpu *= block_pool_size + new_common_attn_metadata.seq_lens *= block_pool_size + new_common_attn_metadata._seq_lens_cpu *= block_pool_size + new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size + new_common_attn_metadata.num_actual_tokens *= block_pool_size + new_common_attn_metadata.max_query_len *= block_pool_size + new_common_attn_metadata.max_seq_len *= block_pool_size + original_slot_mapping = common_attn_metadata.slot_mapping + common_prefix_len *= block_pool_size + new_common_attn_metadata.slot_mapping = ( + ( + original_slot_mapping.unsqueeze(1) * block_pool_size + + torch.arange(block_pool_size, device=original_slot_mapping.device) + ) + .flatten() + .clamp(min=-1) + ) + return super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) + + if not issubclass(underlying_attn_backend, FlashAttentionBackend): + raise NotImplementedError( + f"{underlying_attn_backend} 
is not yet supported." + "Contributions to support more backends are much " + "appreciated." + ) + + attn_backend = subclass_attention_backend_with_overrides( + name_prefix=prefix, + attention_backend_cls=underlying_attn_backend, + overrides={ + "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder, + "get_kv_cache_shape": lambda num_blocks, + block_size, + num_kv_heads, + head_size, + cache_dtype_str: ( + 2, + num_blocks, + # we stretch each block by `block_pool_size` + block_size * block_pool_size, + num_kv_heads // block_pool_size, + head_size, + ), # TODO: generalize to other backends + }, + ) + + return attn_backend + + +class WhisperCausalAttentionWithBlockPooling(Attention): + """Attention layer with block pooling.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int | None = None, + alibi_slopes: list[float] | None = None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + logits_soft_cap: float | None = None, + per_layer_sliding_window: int | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: str | None = None, + block_pool_size: int = 1, + attn_backend: type[AttentionBackend] | None = None, + **extra_impl_args, + ) -> None: + self.block_pool_size = block_pool_size + dtype = torch.get_default_dtype() + + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + block_size = 16 + + underlying_attn_backend = get_attn_backend( + head_size, + dtype, + kv_cache_dtype, + block_size, + attn_type=attn_type, + ) + attn_backend = create_whisper_attention_backend_with_block_pooling( + underlying_attn_backend, block_pool_size + ) + + super().__init__( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=logits_soft_cap, + per_layer_sliding_window=per_layer_sliding_window, + prefix=prefix, + attn_type=attn_type, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + attn_backend=attn_backend, + **extra_impl_args, + ) + + def get_kv_cache_spec(self, vllm_config: VllmConfig): + kv_cache_spec = super().get_kv_cache_spec(vllm_config) + assert isinstance(kv_cache_spec, AttentionSpec) + kv_cache_spec = replace( + kv_cache_spec, + num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads, + ) + return kv_cache_spec + + +class WhisperCausalAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + max_position_embeddings: int, + bias: bool = True, + attn_type: AttentionType = AttentionType.DECODER, + per_layer_sliding_window: int | None = None, + block_pool_size: int = 1, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = embed_dim + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.attn_type = attn_type + + self.scaling = self.head_dim**-0.5 + + self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + assert block_pool_size > 1, ( + f"Causal attention only supports block_pool_size>1, not {block_pool_size}." + ) + self.attn = WhisperCausalAttentionWithBlockPooling( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER, + per_layer_sliding_window=per_layer_sliding_window, + block_pool_size=block_pool_size, + ) + + assert per_layer_sliding_window is not None, ( + "rope can only used in combination with a sliding window" + ) + self._init_rotary_emb(max_position_embeddings) + + def _init_rotary_emb(self, max_position_embeddings: int) -> None: + self.rotary_emb = get_rope( + self.head_dim, + max_position=max_position_embeddings, + is_neox_style=False, + rope_parameters={"rope_theta": 1e6}, + ) + + def _init_qkv( + self, + embed_dim: int, + bias: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor | None = None, + ): + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + assert positions is not None + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_proj(attn_output) + + return output + + +class WhisperCausalEncoderLayer(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + sliding_window = getattr(config, "sliding_window", None) + block_pool_size = config.block_pool_size + assert block_pool_size > 1 + + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.embed_dim = config.d_model + self.head_dim = self.embed_dim // config.encoder_attention_heads + self.self_attn = WhisperCausalAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + head_dim=config.encoder_head_dim, + max_position_embeddings=config.max_position_embeddings, + block_pool_size=block_pool_size, + per_layer_sliding_window=sliding_window, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.self_attn_layer_norm = CausalRMSNorm(self.embed_dim) + + self.mlp = MistralMLP( + hidden_size=config.d_model, + intermediate_size=config.encoder_ffn_dim, + hidden_act="silu", + quant_config=quant_config, + bias=True, + gate_up_proj_bias=False, + prefix=f"{prefix}.mlp", + ) + self.final_layer_norm = CausalRMSNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor | None = None, + ): + residual = hidden_states + 
hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, positions=positions) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperCausalEncoder(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + embed_dim = config.d_model + + assert WhisperPosEmbedType(config.pos_embed) == WhisperPosEmbedType.ROPE + assert config.is_causal + + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = WhisperCausalConv1d(self.num_mel_bins, embed_dim, kernel_size=3) + self.conv2 = WhisperCausalConv1d(embed_dim, embed_dim, stride=2, kernel_size=3) + + self.total_stride = self.conv1.stride[0] * self.conv2.stride[0] + self.start_layer, self.end_layer, self.layers = make_layers( + config.encoder_layers, + lambda prefix: WhisperCausalEncoderLayer( + vllm_config=vllm_config, prefix=f"{prefix}.layers" + ), + prefix=f"{prefix}.layers", + ) + self.layer_norm = CausalRMSNorm(config.d_model) + + def forward_conv( + self, input_features: torch.Tensor | list[torch.Tensor] + ) -> torch.Tensor: + hidden_states = [] + for features in input_features: + embeds = nn.functional.gelu(self.conv1(features)) + embeds = nn.functional.gelu(self.conv2(embeds)) + + embeds = embeds.transpose(-1, -2).to(embeds.dtype) + + hidden_states.append(embeds) + + hidden_states = torch.cat(hidden_states) + + return hidden_states + + def forward( + self, hidden_states: torch.Tensor, positions: torch.Tensor + ) -> torch.Tensor: + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states, positions) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states diff --git a/vllm/model_executor/models/whisper_utils.py b/vllm/model_executor/models/whisper_utils.py index 4d9f7ccf0..4dc7e430c 100644 --- a/vllm/model_executor/models/whisper_utils.py +++ b/vllm/model_executor/models/whisper_utils.py @@ -1,27 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy -import functools -import math -from dataclasses import replace -import torch -import torch.nn.functional as F -from torch import nn - -from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.v1.attention.backend import ( - AttentionBackend, - AttentionMetadata, - AttentionType, - CommonAttentionMetadata, - subclass_attention_backend_with_overrides, -) -from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend -from vllm.v1.attention.selector import get_attn_backend -from vllm.v1.kv_cache_interface import AttentionSpec # From https://platform.openai.com/docs/guides/speech-to-text/supported-languages ISO639_1_SUPPORTED_LANGS = { @@ -83,215 +62,3 @@ ISO639_1_SUPPORTED_LANGS = { "vi": "Vietnamese", "cy": "Welsh", } - - -def _pad1d( - x: torch.Tensor, - paddings: tuple[int, int], - mode: str = "constant", - value: float = 0.0, -) -> torch.Tensor: - """Tiny wrapper around F.pad, just to allow for - reflect padding on small input. 
- If this is the case, we insert extra 0 padding - to the right before the reflection happen. - """ - length = x.shape[-1] - padding_left, padding_right = paddings - assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) - if mode == "reflect": - max_pad = max(padding_left, padding_right) - extra_pad = 0 - if length <= max_pad: - extra_pad = max_pad - length + 1 - x = F.pad(x, (0, extra_pad)) - padded = F.pad(x, paddings, mode, value) - end = padded.shape[-1] - extra_pad - return padded[..., :end] - else: - return F.pad(x, paddings, mode, value) - - -class WhisperCausalConv1d(nn.Conv1d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - bias: bool = True, - ) -> None: - super().__init__( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - bias=bias, - ) - self._stride = self.stride[0] - self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1 - self._padding_total = self._effective_kernel_size - self._stride - - def forward(self, x: torch.Tensor) -> torch.Tensor: - n_frames = ( - x.shape[-1] - self._effective_kernel_size + self._padding_total - ) / self._stride + 1 - target_length = (math.ceil(n_frames) - 1) * self._stride + ( - self._effective_kernel_size - self._padding_total - ) - extra_padding = target_length - x.shape[-1] - x = _pad1d(x, (self._padding_total, extra_padding), mode="constant") - return super().forward(x) - - -@functools.lru_cache -def create_whisper_attention_backend_with_block_pooling( - underlying_attn_backend: AttentionBackend, block_pool_size: int -) -> type[AttentionBackend]: - prefix = "WhisperAttentionWithBlockPooling_" - underlying_builder = underlying_attn_backend.get_builder_cls() - - class WhisperAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore - def __init__( - self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - ): - assert kv_cache_spec.num_kv_heads % block_pool_size == 0 - kv_cache_spec = replace( - kv_cache_spec, - block_size=kv_cache_spec.block_size * block_pool_size, - num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size, - ) - super().__init__(kv_cache_spec, layer_names, vllm_config, device) - - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> AttentionMetadata: - new_common_attn_metadata = copy.deepcopy(common_attn_metadata) - new_common_attn_metadata.query_start_loc *= block_pool_size - new_common_attn_metadata.query_start_loc_cpu *= block_pool_size - new_common_attn_metadata.seq_lens *= block_pool_size - new_common_attn_metadata._seq_lens_cpu *= block_pool_size - new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size - new_common_attn_metadata.num_actual_tokens *= block_pool_size - new_common_attn_metadata.max_query_len *= block_pool_size - new_common_attn_metadata.max_seq_len *= block_pool_size - original_slot_mapping = common_attn_metadata.slot_mapping - common_prefix_len *= block_pool_size - new_common_attn_metadata.slot_mapping = ( - ( - original_slot_mapping.unsqueeze(1) * block_pool_size - + torch.arange(block_pool_size, device=original_slot_mapping.device) - ) - .flatten() - .clamp(min=-1) - ) - return super().build( - common_prefix_len, new_common_attn_metadata, fast_build - ) - - if not issubclass(underlying_attn_backend, FlashAttentionBackend): - raise NotImplementedError( - f"{underlying_attn_backend} is not yet 
supported." - "Contributions to support more backends are much " - "appreciated." - ) - - attn_backend = subclass_attention_backend_with_overrides( - name_prefix=prefix, - attention_backend_cls=underlying_attn_backend, - overrides={ - "get_builder_cls": lambda: WhisperAttentionWithBlockPoolingBuilder, - "get_kv_cache_shape": lambda num_blocks, - block_size, - num_kv_heads, - head_size, - cache_dtype_str: ( - 2, - num_blocks, - # we stretch each block by `block_pool_size` - block_size * block_pool_size, - num_kv_heads // block_pool_size, - head_size, - ), # TODO: generalize to other backends - }, - ) - - return attn_backend - - -class WhisperAttentionWithBlockPooling(Attention): - """Attention layer with block pooling.""" - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int | None = None, - alibi_slopes: list[float] | None = None, - cache_config: CacheConfig | None = None, - quant_config: QuantizationConfig | None = None, - logits_soft_cap: float | None = None, - per_layer_sliding_window: int | None = None, - prefix: str = "", - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: str | None = None, - block_pool_size: int = 1, - attn_backend: type[AttentionBackend] | None = None, - **extra_impl_args, - ) -> None: - self.block_pool_size = block_pool_size - dtype = torch.get_default_dtype() - - if cache_config is not None: - kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size - else: - kv_cache_dtype = "auto" - block_size = 16 - - underlying_attn_backend = get_attn_backend( - head_size, - dtype, - kv_cache_dtype, - block_size, - attn_type=attn_type, - ) - attn_backend = create_whisper_attention_backend_with_block_pooling( - underlying_attn_backend, block_pool_size - ) - - super().__init__( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=alibi_slopes, - cache_config=cache_config, - quant_config=quant_config, - logits_soft_cap=logits_soft_cap, - per_layer_sliding_window=per_layer_sliding_window, - prefix=prefix, - attn_type=attn_type, - kv_sharing_target_layer_name=kv_sharing_target_layer_name, - attn_backend=attn_backend, - **extra_impl_args, - ) - - def get_kv_cache_spec(self, vllm_config: VllmConfig): - kv_cache_spec = super().get_kv_cache_spec(vllm_config) - assert isinstance(kv_cache_spec, AttentionSpec) - kv_cache_spec = replace( - kv_cache_spec, - num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads, - ) - return kv_cache_spec diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py index 4776c892e..0bf282c8e 100644 --- a/vllm/transformers_utils/configs/mistral.py +++ b/vllm/transformers_utils/configs/mistral.py @@ -224,6 +224,7 @@ def _remap_mistral_audio_args(config: dict) -> dict: encoder_layers=encoder_args["n_layers"], encoder_ffn_dim=encoder_args["hidden_dim"], encoder_attention_heads=encoder_args["n_heads"], + encoder_head_dim=encoder_args["head_dim"], vocab_size=encoder_args["vocab_size"], max_source_positions=encoder_args["max_source_positions"], is_encoder_decoder=False, # Override WhisperConfig default @@ -231,6 +232,8 @@ def _remap_mistral_audio_args(config: dict) -> dict: sliding_window=sliding_window, block_pool_size=block_pool_size, pos_embed=encoder_args.get("pos_embed", "sinusoidal"), + # only needed for RoPE + max_position_embeddings=block_pool_size * config["max_position_embeddings"], ), } if quant_config: