[Voxtral] Add new streaming arch (#32861)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
commit 3f3f89529d (parent 5da4c7d789)
@@ -404,6 +404,7 @@ class LlamaModel(nn.Module):
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        **extra_layer_kwargs,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
@@ -422,7 +423,9 @@ class LlamaModel(nn.Module):
        ):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)
            hidden_states, residual = layer(
                positions, hidden_states, residual, **extra_layer_kwargs
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
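The new `**extra_layer_kwargs` is forwarded untouched to every decoder-layer call, so subclasses can thread extra conditioning tensors (such as Voxtral's `t_cond`) through `LlamaModel.forward` without overriding the layer loop. A minimal sketch of the idea, using a hypothetical stand-in layer that accepts the extra keyword:

import torch
from torch import nn

class TinyLayer(nn.Module):
    # Hypothetical stand-in for a decoder layer whose forward accepts an
    # optional conditioning tensor, mirroring MistralDecoderLayer below.
    def forward(self, positions, hidden_states, residual, t_cond=None):
        if t_cond is not None:
            hidden_states = hidden_states * (1 + t_cond)
        return hidden_states, residual

layers = [TinyLayer(), TinyLayer()]
positions = torch.arange(4)
hidden_states, residual = torch.zeros(4, 8), None
extra_layer_kwargs = {"t_cond": torch.full((4, 8), 0.1)}
for layer in layers:
    # Same call shape as the modified loop in LlamaModel.forward.
    hidden_states, residual = layer(
        positions, hidden_states, residual, **extra_layer_kwargs
    )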

@@ -10,6 +10,12 @@ from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import (
    LlamaAttention,
@@ -17,11 +23,57 @@ from vllm.model_executor.models.llama import (
    LlamaForCausalLM,
    LlamaModel,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType

from .utils import AutoWeightsLoader


class MistralMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        gate_up_proj_bias: bool | None = None,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        gate_up_proj_bias = bias if gate_up_proj_bias is None else gate_up_proj_bias
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=gate_up_proj_bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x
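`MistralMLP` fuses the gate and up projections into a single `MergedColumnParallelLinear` whose output has `2 * intermediate_size` features; `SiluAndMul` then splits that tensor in half and returns `silu(gate) * up` before the down projection. A plain-torch sketch of the same math, assuming the first half of the fused output is the gate half:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 16
x = torch.randn(2, hidden_size)
gate_up_w = torch.randn(2 * intermediate_size, hidden_size)  # fused gate/up weight
down_w = torch.randn(hidden_size, intermediate_size)

fused = x @ gate_up_w.t()              # (2, 2 * intermediate_size)
gate, up = fused.chunk(2, dim=-1)      # what SiluAndMul splits internally
out = (F.silu(gate) * up) @ down_w.t()
print(out.shape)                       # torch.Size([2, 8])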


class MistralAttention(LlamaAttention):
    def __init__(
        self,
@@ -114,6 +166,50 @@ class MistralDecoderLayer(LlamaDecoderLayer):
        self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
        self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj

        if getattr(config, "ada_rms_norm_t_cond", False):
            self.ada_rms_norm_t_cond = nn.Sequential(
                ColumnParallelLinear(
                    input_size=config.hidden_size,
                    output_size=config.ada_rms_norm_t_cond_dim,
                    bias=False,
                    return_bias=False,
                ),
                nn.GELU(),
                RowParallelLinear(
                    input_size=config.ada_rms_norm_t_cond_dim,
                    output_size=config.hidden_size,
                    bias=False,
                    return_bias=False,
                ),
            )
        else:
            self.ada_rms_norm_t_cond = None

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        t_cond: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        if self.ada_rms_norm_t_cond is not None:
            assert t_cond is not None
            hidden_states = hidden_states * (1 + self.ada_rms_norm_t_cond(t_cond))

        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
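The `ada_rms_norm_t_cond` path is an adaptive scaling of the post-attention hidden states: a small bottleneck MLP maps the time-conditioning vector `t_cond` to a per-channel scale, and the hidden states are multiplied by `1 + scale` before the feed-forward block. A plain-torch sketch of that gating, with illustrative sizes in place of the real config values:

import torch
from torch import nn

hidden_size, t_cond_dim = 8, 4
# Stand-in for the ColumnParallelLinear -> GELU -> RowParallelLinear stack.
ada_mlp = nn.Sequential(
    nn.Linear(hidden_size, t_cond_dim, bias=False),
    nn.GELU(),
    nn.Linear(t_cond_dim, hidden_size, bias=False),
)

hidden_states = torch.randn(6, hidden_size)  # one row per token
t_cond = torch.randn(6, hidden_size)         # time embedding, one row per token
hidden_states = hidden_states * (1 + ada_mlp(t_cond))
print(hidden_states.shape)                   # torch.Size([6, 8])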


@support_torch_compile
class MistralModel(LlamaModel):
@@ -126,6 +222,18 @@ class MistralModel(LlamaModel):
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        t_cond: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        return super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds, t_cond=t_cond
        )


class MistralForCausalLM(LlamaForCausalLM):
    # Mistral: We don't support LoRA on the embedding layers.

@@ -4,7 +4,7 @@
import inspect
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from functools import cached_property, partial
from math import ceil
from typing import Literal, cast

@@ -33,7 +33,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.model_executor.models.whisper import (
    WhisperEncoder,
    _create_fake_bias_for_k_proj,
)
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
@@ -543,6 +547,7 @@ class VoxtralForConditionalGeneration(
                }
            ).named_parameters()
        )
        weights = _create_fake_bias_for_k_proj(weights, ".wk.weight")

        loaded_weights = set()

@@ -730,6 +735,10 @@ class VoxtralEncoderModel(nn.Module):
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)",  # noqa: E501
            r"whisper_encoder.layers.\1.mlp.fc2.\2",
        ),
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(weight|bias)",
            r"whisper_encoder.layers.\1.mlp.fc3.\2",
        ),  # noqa: E501
        (
            r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
            r"whisper_encoder.layers.\1.final_layer_norm.\2",
@@ -749,10 +758,15 @@ class VoxtralEncoderModel(nn.Module):
        super().__init__()
        self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
        self.dtype: torch.dtype = vllm_config.model_config.dtype
        self.whisper_encoder = WhisperEncoder(
        self.is_causal = getattr(self.config, "is_causal", False)
        if self.is_causal:
            WhisperEncoderCls = WhisperCausalEncoder
        else:
            WhisperEncoderCls = partial(WhisperEncoder, init_in_fp32=True)

        self.whisper_encoder = WhisperEncoderCls(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "whisper_encoder"),
            init_in_fp32=True,
        )
        mel_filters = mel_filter_bank(
            num_frequency_bins=1 + self.config.window_size // 2,
@@ -843,6 +857,22 @@ class VoxtralEncoderModel(nn.Module):
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_mapping = []

        if self.is_causal:
            # For `WhisperCausalEncoder` we need
            # some more renaming
            stacked_params_mapping.extend(
                [
                    (".mlp.gate_up_proj", ".mlp.fc1", 0),
                    (".mlp.gate_up_proj", ".mlp.fc3", 1),
                ]
            )
            params_mapping.extend(
                [
                    (".mlp.down_proj", ".mlp.fc2"),
                ]
            )
        params_dict = dict(self.named_parameters())

        name, loaded_weight = weight
@@ -860,6 +890,11 @@ class VoxtralEncoderModel(nn.Module):
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            for param_name, weight_name in params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)

@@ -112,6 +112,18 @@ class TimeEmbedding(torch.nn.Module):
        return torch.cat((emb.cos(), emb.sin()), dim=-1)  # (B, D) or (B, T, D)


def _expand_tensor(input_tensor: torch.Tensor, scaling: int) -> torch.Tensor:
    # 1. Multiply by the scaling factor (e.g. 4)
    base = input_tensor * scaling

    # 2. Create the offsets, e.g. [0, 1, 2, 3]
    offsets = torch.arange(scaling, device=input_tensor.device)

    # 3. Use broadcasting, e.g. (N, 1) + (4,) results in (N, 4)
    #    Then flatten back to 1D
    return (base.unsqueeze(1) + offsets).view(-1)
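`_expand_tensor` turns each language-model position into `scaling` consecutive encoder positions, which is how the pooled decoder positions are mapped back to per-frame whisper positions. A quick illustration, assuming a 1-D positions tensor:

import torch

positions = torch.tensor([0, 1, 2])
print(_expand_tensor(positions, 4))
# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])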


@MULTIMODAL_REGISTRY.register_processor(
    VoxtralStreamingMultiModalProcessor,
    info=VoxtralProcessingInfo,
@@ -175,8 +187,9 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
            inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
        )

        audio_hidden_states = self.whisper_encoder.whisper_encoder.forward_layers(
            inputs_embeds
        whisper_positions = _expand_tensor(positions, pool_size)
        audio_hidden_states = self.whisper_encoder.whisper_encoder(
            inputs_embeds, whisper_positions
        )

        num_tokens, audio_hidden_size = audio_hidden_states.shape
@@ -197,10 +210,14 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
            device=inputs_embeds.device,
            dtype=inputs_embeds.dtype,
        )
        inputs_embeds = inputs_embeds + self.time_embedding(time_tensor)
        t_cond = self.time_embedding(time_tensor)

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
            t_cond=t_cond,
        )

        return hidden_states
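Before entering the whisper layers, the pooled audio embeddings appear to be un-stacked so that the causal encoder sees one token per audio frame: a row of width `hidden * pool_size` becomes `pool_size` rows of width `hidden`, and `_expand_tensor` supplies the matching per-frame positions. A toy shape check with made-up sizes and `pool_size = 4`:

import torch

pool_size, num_pooled, whisper_hidden = 4, 3, 8
inputs_embeds = torch.randn(num_pooled, whisper_hidden * pool_size)
per_frame = inputs_embeds.view(
    inputs_embeds.shape[0] * pool_size, inputs_embeds.shape[1] // pool_size
)
positions = torch.arange(num_pooled)
whisper_positions = (
    positions.unsqueeze(1) * pool_size + torch.arange(pool_size)
).view(-1)
print(per_frame.shape)             # torch.Size([12, 8])
print(whisper_positions.tolist())  # 0..11, one entry per whisper frame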

@@ -5,7 +5,6 @@ import enum
import math
from collections.abc import Iterable, Mapping, Sequence
from contextlib import nullcontext
from functools import partial
from typing import Annotated, Literal, cast

import numpy as np
@@ -39,8 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.whisper_utils import (
    ISO639_1_SUPPORTED_LANGS,
    WhisperAttentionWithBlockPooling,
    WhisperCausalConv1d,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
@@ -78,7 +75,7 @@ logger = init_logger(__name__)

class WhisperPosEmbedType(enum.Enum):
    SINUSOIDAL = "sinusoidal"
    NOPE = "nope"
    ROPE = "rope"
    LEARNED = "learned"


@@ -140,7 +137,6 @@ class WhisperAttention(nn.Module):
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        block_pool_size: int = 1,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
@@ -199,14 +195,7 @@ class WhisperAttention(nn.Module):
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            if block_pool_size > 1:
                attn_cls = partial(
                    WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
                )
            else:
                attn_cls = Attention

            self.attn = attn_cls(
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
@@ -351,9 +340,7 @@ class WhisperEncoderLayer(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        is_causal = getattr(config, "is_causal", False)
        sliding_window = getattr(config, "sliding_window", None)
        block_pool_size = getattr(config, "block_pool_size", 1)
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

@@ -361,8 +348,7 @@ class WhisperEncoderLayer(nn.Module):
        self.self_attn = WhisperAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
            block_pool_size=block_pool_size,
            attn_type=AttentionType.ENCODER,
            per_layer_sliding_window=sliding_window,
            cache_config=cache_config,
            quant_config=quant_config,
@@ -470,13 +456,8 @@ class WhisperEncoder(nn.Module):
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.is_causal = getattr(config, "is_causal", False)
        Conv1d = (
            WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
        )

        self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
        self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)

        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
        self.start_layer, self.end_layer, self.layers = make_layers(
@@ -488,33 +469,29 @@ class WhisperEncoder(nn.Module):
        )
        self.layer_norm = nn.LayerNorm(config.d_model)

        if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
            raise ValueError(
                "Only NOPE position embeddings are supported "
                f"for causal models, but got {self.pos_embed_type}"
            )
        elif self.pos_embed_type in (
        if self.pos_embed_type not in (
            WhisperPosEmbedType.SINUSOIDAL,
            WhisperPosEmbedType.LEARNED,
        ):
            maybe_fp32_init_ctx = (
                set_default_torch_dtype(torch.float32)
                if init_in_fp32
                else nullcontext()
            raise ValueError(
                "Only sinusoidal or learned position embeddings are supported "
                f"for non-causal models, but got {self.pos_embed_type}"
            )

            with (
                torch.no_grad(),
                maybe_fp32_init_ctx,
            ):
                self.embed_positions = nn.Embedding(
                    self.max_source_positions, embed_dim
                )
                self.embed_positions.weight.copy_(
                    sinusoids(*self.embed_positions.weight.shape)
                )
        maybe_fp32_init_ctx = (
            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
        )

    def forward_conv(
        with (
            torch.no_grad(),
            maybe_fp32_init_ctx,
        ):
            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
            self.embed_positions.weight.copy_(
                sinusoids(*self.embed_positions.weight.shape)
            )

    def forward(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        hidden_states = []
@@ -523,44 +500,26 @@ class WhisperEncoder(nn.Module):
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))

            if self.pos_embed_type in (
                WhisperPosEmbedType.SINUSOIDAL,
                WhisperPosEmbedType.LEARNED,
            ):
                embeds = embeds.transpose(-1, -2)
                embeds = (
                    embeds + self.embed_positions.weight[: embeds.size(-2), :]
                ).to(embeds.dtype)
            elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
                embeds = embeds.transpose(-1, -2).to(embeds.dtype)
            else:
                raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
            embeds = embeds.transpose(-1, -2)
            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
                embeds.dtype
            )

            hidden_states.append(embeds)
        input_is_batched = embeds.ndim > 2
        # Input to MHA must be B x T x D
        if input_is_batched or self.is_causal:
        if input_is_batched:
            # Models using WhisperEncoder may handle batching internally.
            # If WhisperEncoder is causal, sequences
            # are not padded to have identical seq length (T)
            # => concat over feature dim
            hidden_states = torch.cat(hidden_states)
        else:
            hidden_states = torch.stack(hidden_states, dim=0)

        return hidden_states

    def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states

    def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
        hidden_states = self.forward_conv(input_features)
        return self.forward_layers(hidden_states)


@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class WhisperDecoder(nn.Module):
@@ -978,19 +937,19 @@ class WhisperForConditionalGeneration(
        loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])

        # add fake zeros bias for k_proj to state_dict
        weights = _create_fake_bias_for_k_proj(weights)
        weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


def _create_fake_bias_for_k_proj(
    weights: Iterable[tuple[str, torch.Tensor]],
    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
) -> Iterable[tuple[str, torch.Tensor]]:
    """
    Create full zeros bias for k_proj weight in self-attn and x-attn layers.
    So that the bias for k_proj in qkv_proj can be initialized with zeros.
    """
    for name, weight in weights:
        if name.endswith(".k_proj.weight"):
        if name.endswith(fake_bias_key_name):
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
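The helper now takes the key suffix to match, so Voxtral can pass `".wk.weight"` while Whisper keeps `".k_proj.weight"`. A small usage sketch with a dummy weight:

import torch

weights = [("encoder.layers.0.self_attn.k_proj.weight", torch.randn(4, 4))]
out = list(_create_fake_bias_for_k_proj(iter(weights), ".k_proj.weight"))
print([name for name, _ in out])
# ['encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias']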

vllm/model_executor/models/whisper_causal.py (new file, 465 lines)
@@ -0,0 +1,465 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import functools
import math
from dataclasses import replace
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn

from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.mistral import MistralMLP
from vllm.model_executor.models.whisper import WhisperPosEmbedType
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
    CommonAttentionMetadata,
    subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import AttentionSpec

from .utils import make_layers

CausalRMSNorm = partial(RMSNorm, eps=1e-5)


def _pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
) -> torch.Tensor:
    """Tiny wrapper around F.pad, just to allow for
    reflect padding on small input.
    If this is the case, we insert extra 0 padding
    to the right before the reflection happen.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
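`_pad1d` only special-cases reflect mode: when the signal is shorter than the requested padding, it temporarily right-pads with zeros so the reflection does not fail, then trims the extra samples off the result. A quick check of both paths:

import torch

x = torch.tensor([[[1.0, 2.0]]])        # (batch, channels, time), length 2
y = _pad1d(x, (3, 3), mode="reflect")   # plain F.pad would raise here
print(y.shape)                          # torch.Size([1, 1, 8]) -> 3 + 2 + 3 samples
print(_pad1d(x, (2, 0)).tolist())       # constant mode: [[[0.0, 0.0, 1.0, 2.0]]]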


class WhisperCausalConv1d(nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self._stride = self.stride[0]
        self._effective_kernel_size = (kernel_size - 1) * self.dilation[0] + 1
        self._padding_total = self._effective_kernel_size - self._stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_frames = (
            x.shape[-1] - self._effective_kernel_size + self._padding_total
        ) / self._stride + 1
        target_length = (math.ceil(n_frames) - 1) * self._stride + (
            self._effective_kernel_size - self._padding_total
        )
        extra_padding = target_length - x.shape[-1]
        x = _pad1d(x, (self._padding_total, extra_padding), mode="constant")
        return super().forward(x)
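The convolution is made causal by putting all of the "kernel minus stride" padding on the left, plus whatever right padding is needed to complete the last frame, so no output frame looks ahead of its input. For the two kernel-3 convolutions used below, the arithmetic works out as in this sketch:

import torch

# kernel_size=3, stride=1: padding_total=2 left pads, output length == input length
# kernel_size=3, stride=2: padding_total=1 left pad,  output length == ceil(T / 2)
conv = WhisperCausalConv1d(in_channels=1, out_channels=1, kernel_size=3, stride=2)
x = torch.randn(1, 1, 7)
print(conv(x).shape)  # torch.Size([1, 1, 4]) -> ceil(7 / 2) frames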


@functools.lru_cache
def create_whisper_attention_backend_with_block_pooling(
    underlying_attn_backend: AttentionBackend, block_pool_size: int
) -> type[AttentionBackend]:
    prefix = "WhisperCausalAttentionWithBlockPooling_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder):  # type: ignore
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            assert kv_cache_spec.num_kv_heads % block_pool_size == 0
            kv_cache_spec = replace(
                kv_cache_spec,
                block_size=kv_cache_spec.block_size * block_pool_size,
                num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
            )
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            new_common_attn_metadata = copy.deepcopy(common_attn_metadata)
            new_common_attn_metadata.query_start_loc *= block_pool_size
            new_common_attn_metadata.query_start_loc_cpu *= block_pool_size
            new_common_attn_metadata.seq_lens *= block_pool_size
            new_common_attn_metadata._seq_lens_cpu *= block_pool_size
            new_common_attn_metadata._num_computed_tokens_cpu *= block_pool_size
            new_common_attn_metadata.num_actual_tokens *= block_pool_size
            new_common_attn_metadata.max_query_len *= block_pool_size
            new_common_attn_metadata.max_seq_len *= block_pool_size
            original_slot_mapping = common_attn_metadata.slot_mapping
            common_prefix_len *= block_pool_size
            new_common_attn_metadata.slot_mapping = (
                (
                    original_slot_mapping.unsqueeze(1) * block_pool_size
                    + torch.arange(block_pool_size, device=original_slot_mapping.device)
                )
                .flatten()
                .clamp(min=-1)
            )
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
        raise NotImplementedError(
            f"{underlying_attn_backend} is not yet supported."
            "Contributions to support more backends are much "
            "appreciated."
        )

    attn_backend = subclass_attention_backend_with_overrides(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        overrides={
            "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
            "get_kv_cache_shape": lambda num_blocks,
            block_size,
            num_kv_heads,
            head_size,
            cache_dtype_str: (
                2,
                num_blocks,
                # we stretch each block by `block_pool_size`
                block_size * block_pool_size,
                num_kv_heads // block_pool_size,
                head_size,
            ),  # TODO: generalize to other backends
        },
    )

    return attn_backend
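The builder presents each pooled KV "token" to the underlying FlashAttention backend as `block_pool_size` real tokens: sequence lengths, query offsets and block sizes are scaled up, and every slot index fans out into a contiguous run of slots (padding slots of -1 stay -1 thanks to the clamp). A standalone sketch of the slot-mapping expansion, assuming `block_pool_size = 4`:

import torch

block_pool_size = 4
slot_mapping = torch.tensor([0, 1, 5, -1])  # -1 marks a padding slot
expanded = (
    (slot_mapping.unsqueeze(1) * block_pool_size + torch.arange(block_pool_size))
    .flatten()
    .clamp(min=-1)
)
print(expanded.tolist())
# [0, 1, 2, 3, 4, 5, 6, 7, 20, 21, 22, 23, -1, -1, -1, -1]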


class WhisperCausalAttentionWithBlockPooling(Attention):
    """Attention layer with block pooling."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        block_pool_size: int = 1,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        self.block_pool_size = block_pool_size
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=attn_type,
        )
        attn_backend = create_whisper_attention_backend_with_block_pooling(
            underlying_attn_backend, block_pool_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=prefix,
            attn_type=attn_type,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend,
            **extra_impl_args,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig):
        kv_cache_spec = super().get_kv_cache_spec(vllm_config)
        assert isinstance(kv_cache_spec, AttentionSpec)
        kv_cache_spec = replace(
            kv_cache_spec,
            num_kv_heads=self.block_pool_size * kv_cache_spec.num_kv_heads,
        )
        return kv_cache_spec
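The KV-cache bookkeeping stays consistent between the two views: the spec reports `block_pool_size * num_kv_heads` heads per pooled token, while the builder and the `get_kv_cache_shape` override reinterpret the same storage as `block_size * block_pool_size` frame slots with the real head count, so the element count per cache block is unchanged. A quick arithmetic check with illustrative numbers:

# Illustrative numbers, not taken from any real config.
block_size, num_kv_heads, head_size, block_pool_size = 16, 8, 64, 4

spec_elems = block_size * (block_pool_size * num_kv_heads) * head_size
backend_elems = (block_size * block_pool_size) * num_kv_heads * head_size
assert spec_elems == backend_elems == 32768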


class WhisperCausalAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        per_layer_sliding_window: int | None = None,
        block_pool_size: int = 1,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        if self.total_num_heads >= tp_size:
            # Number of heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_heads % tp_size == 0
        else:
            # Number of heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_heads == 0
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type

        self.scaling = self.head_dim**-0.5

        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )
        assert block_pool_size > 1, (
            f"Causal attention only supports block_pool_size>1, not {block_pool_size}."
        )
        self.attn = WhisperCausalAttentionWithBlockPooling(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.DECODER,
            per_layer_sliding_window=per_layer_sliding_window,
            block_pool_size=block_pool_size,
        )

        assert per_layer_sliding_window is not None, (
            "rope can only used in combination with a sliding window"
        )
        self._init_rotary_emb(max_position_embeddings)

    def _init_rotary_emb(self, max_position_embeddings: int) -> None:
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            is_neox_style=False,
            rope_parameters={"rope_theta": 1e6},
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor | None = None,
    ):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        assert positions is not None
        q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)

        return output


class WhisperCausalEncoderLayer(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        sliding_window = getattr(config, "sliding_window", None)
        block_pool_size = config.block_pool_size
        assert block_pool_size > 1

        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.d_model
        self.head_dim = self.embed_dim // config.encoder_attention_heads
        self.self_attn = WhisperCausalAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            head_dim=config.encoder_head_dim,
            max_position_embeddings=config.max_position_embeddings,
            block_pool_size=block_pool_size,
            per_layer_sliding_window=sliding_window,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = CausalRMSNorm(self.embed_dim)

        self.mlp = MistralMLP(
            hidden_size=config.d_model,
            intermediate_size=config.encoder_ffn_dim,
            hidden_act="silu",
            quant_config=quant_config,
            bias=True,
            gate_up_proj_bias=False,
            prefix=f"{prefix}.mlp",
        )
        self.final_layer_norm = CausalRMSNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor | None = None,
    ):
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states=hidden_states, positions=positions)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class WhisperCausalEncoder(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        embed_dim = config.d_model

        assert WhisperPosEmbedType(config.pos_embed) == WhisperPosEmbedType.ROPE
        assert config.is_causal

        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.conv1 = WhisperCausalConv1d(self.num_mel_bins, embed_dim, kernel_size=3)
        self.conv2 = WhisperCausalConv1d(embed_dim, embed_dim, stride=2, kernel_size=3)

        self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.encoder_layers,
            lambda prefix: WhisperCausalEncoderLayer(
                vllm_config=vllm_config, prefix=f"{prefix}.layers"
            ),
            prefix=f"{prefix}.layers",
        )
        self.layer_norm = CausalRMSNorm(config.d_model)

    def forward_conv(
        self, input_features: torch.Tensor | list[torch.Tensor]
    ) -> torch.Tensor:
        hidden_states = []
        for features in input_features:
            embeds = nn.functional.gelu(self.conv1(features))
            embeds = nn.functional.gelu(self.conv2(embeds))

            embeds = embeds.transpose(-1, -2).to(embeds.dtype)

            hidden_states.append(embeds)

        hidden_states = torch.cat(hidden_states)

        return hidden_states

    def forward(
        self, hidden_states: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states, positions)

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
@@ -1,27 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import functools
import math
from dataclasses import replace

import torch
import torch.nn.functional as F
from torch import nn

from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
    CommonAttentionMetadata,
    subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import AttentionSpec

# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = {
@@ -83,215 +62,3 @@ ISO639_1_SUPPORTED_LANGS = {
    "vi": "Vietnamese",
    "cy": "Welsh",
}

[The remainder of this hunk removes the old `_pad1d`, `WhisperCausalConv1d`, `create_whisper_attention_backend_with_block_pooling`, and `WhisperAttentionWithBlockPooling` definitions from whisper_utils.py; apart from the `Causal` naming, they are identical to the new whisper_causal.py versions shown above.]