[Voxtral] Add new streaming arch (#32861)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Patrick von Platen
2026-01-23 12:41:52 +01:00
committed by GitHub
parent 5da4c7d789
commit 3f3f89529d
9 changed files with 767 additions and 313 deletions


@@ -10,6 +10,12 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import (
    LlamaAttention,
@@ -17,11 +23,57 @@ from vllm.model_executor.models.llama import (
    LlamaForCausalLM,
    LlamaModel,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType
from .utils import AutoWeightsLoader


class MistralMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        gate_up_proj_bias: bool | None = None,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        gate_up_proj_bias = bias if gate_up_proj_bias is None else gate_up_proj_bias
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=gate_up_proj_bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x
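
# Illustrative sketch (not part of this diff): the computation the fused
# gate/up projection plus SiluAndMul above performs, written in plain PyTorch
# with no tensor parallelism or quantization. All names here are hypothetical
# and only serve to clarify the data flow.
import torch
import torch.nn.functional as F


def gated_silu_mlp_reference(
    x: torch.Tensor,
    gate_up_weight: torch.Tensor,  # (2 * intermediate_size, hidden_size)
    down_weight: torch.Tensor,     # (hidden_size, intermediate_size)
) -> torch.Tensor:
    # gate_up_proj fuses the gate and up projections into a single matmul.
    gate, up = (x @ gate_up_weight.t()).chunk(2, dim=-1)
    # SiluAndMul applies the SwiGLU-style gating: silu(gate) * up.
    return (F.silu(gate) * up) @ down_weight.t()
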

class MistralAttention(LlamaAttention):
    def __init__(
        self,
@@ -114,6 +166,50 @@ class MistralDecoderLayer(LlamaDecoderLayer):
        self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
        self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj

        if getattr(config, "ada_rms_norm_t_cond", False):
            self.ada_rms_norm_t_cond = nn.Sequential(
                ColumnParallelLinear(
                    input_size=config.hidden_size,
                    output_size=config.ada_rms_norm_t_cond_dim,
                    bias=False,
                    return_bias=False,
                ),
                nn.GELU(),
                RowParallelLinear(
                    input_size=config.ada_rms_norm_t_cond_dim,
                    output_size=config.hidden_size,
                    bias=False,
                    return_bias=False,
                ),
            )
        else:
            self.ada_rms_norm_t_cond = None

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        t_cond: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        if self.ada_rms_norm_t_cond is not None:
            assert t_cond is not None
            hidden_states = hidden_states * (1 + self.ada_rms_norm_t_cond(t_cond))
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
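
# Illustrative sketch (not part of this diff): the ada_rms_norm_t_cond path
# above in isolation. A small GELU MLP projects the conditioning vector t_cond
# to a per-channel scale, and the post-attention-normalized hidden states are
# modulated by (1 + scale) before the feed-forward block. Plain nn.Linear
# stands in for the Column/RowParallelLinear layers; names are hypothetical.
import torch
import torch.nn as nn


class TimeCondScale(nn.Module):
    def __init__(self, hidden_size: int, t_cond_dim: int) -> None:
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(hidden_size, t_cond_dim, bias=False),
            nn.GELU(),
            nn.Linear(t_cond_dim, hidden_size, bias=False),
        )

    def forward(self, hidden_states: torch.Tensor, t_cond: torch.Tensor) -> torch.Tensor:
        # (1 + scale) keeps the layer an identity map when the projection outputs zero.
        return hidden_states * (1 + self.proj(t_cond))
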
@support_torch_compile
class MistralModel(LlamaModel):
@@ -126,6 +222,18 @@ class MistralModel(LlamaModel):
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        t_cond: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        return super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds, t_cond=t_cond
        )

class MistralForCausalLM(LlamaForCausalLM):
    # Mistral: We don't support LoRA on the embedding layers.