[Voxtral] Add new streaming arch (#32861)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
commit 3f3f89529d (parent 5da4c7d789), committed by GitHub
@@ -10,6 +10,12 @@ from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import (
    LlamaAttention,

@@ -17,11 +23,57 @@ from vllm.model_executor.models.llama import (
    LlamaForCausalLM,
    LlamaModel,
)
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionType

from .utils import AutoWeightsLoader

class MistralMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        gate_up_proj_bias: bool | None = None,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        gate_up_proj_bias = bias if gate_up_proj_bias is None else gate_up_proj_bias
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            bias=gate_up_proj_bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x
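
For reference, the fused gate/up projection followed by SiluAndMul computes silu(gate) * up before the down projection. A minimal single-GPU sketch of the same computation, assuming plain nn.Linear layers stand in for the tensor-parallel ones (illustration only, not the vLLM classes):

# Hypothetical single-GPU equivalent of MistralMLP (illustration only).
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReferenceMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # One matmul produces both gate and up activations, mirroring
        # MergedColumnParallelLinear with two equal output slices.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        # SiluAndMul: silu(gate) * up, then project back to hidden_size.
        return self.down_proj(F.silu(gate) * up)

x = torch.randn(2, 8, 1024)
assert ReferenceMLP(1024, 4096)(x).shape == x.shape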

class MistralAttention(LlamaAttention):
    def __init__(
        self,
@@ -114,6 +166,50 @@ class MistralDecoderLayer(LlamaDecoderLayer):
        self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
        self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj

        if getattr(config, "ada_rms_norm_t_cond", False):
            self.ada_rms_norm_t_cond = nn.Sequential(
                ColumnParallelLinear(
                    input_size=config.hidden_size,
                    output_size=config.ada_rms_norm_t_cond_dim,
                    bias=False,
                    return_bias=False,
                ),
                nn.GELU(),
                RowParallelLinear(
                    input_size=config.ada_rms_norm_t_cond_dim,
                    output_size=config.hidden_size,
                    bias=False,
                    return_bias=False,
                ),
            )
        else:
            self.ada_rms_norm_t_cond = None

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        t_cond: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        if self.ada_rms_norm_t_cond is not None:
            assert t_cond is not None
            hidden_states = hidden_states * (1 + self.ada_rms_norm_t_cond(t_cond))

        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
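
The new ada_rms_norm_t_cond branch applies an adaptive scale to the post-attention-layernorm output: the conditioning tensor is passed through a small GELU MLP and the result gates the hidden states as hidden_states * (1 + mlp(t_cond)). A minimal sketch with plain nn.Linear layers and made-up sizes (not the parallel vLLM layers):

# Hypothetical shapes; plain Linear layers stand in for the parallel ones.
import torch
import torch.nn as nn

hidden_size, t_cond_dim = 1024, 256
ada_rms_norm_t_cond = nn.Sequential(
    nn.Linear(hidden_size, t_cond_dim, bias=False),
    nn.GELU(),
    nn.Linear(t_cond_dim, hidden_size, bias=False),
)

hidden_states = torch.randn(4, hidden_size)  # post_attention_layernorm output
t_cond = torch.randn(4, hidden_size)         # conditioning input with hidden_size features
# Scale the normalized hidden states before the MLP block, as in forward() above.
hidden_states = hidden_states * (1 + ada_rms_norm_t_cond(t_cond))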

@support_torch_compile
class MistralModel(LlamaModel):

@@ -126,6 +222,18 @@ class MistralModel(LlamaModel):
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
        t_cond: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        return super().forward(
            input_ids, positions, intermediate_tensors, inputs_embeds, t_cond=t_cond
        )
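
The override above exists only to accept the extra t_cond keyword and forward it to the base model. A hypothetical sketch of how such a conditioning tensor gets threaded through a decoder stack (the dummy layer and loop below are illustrations, not vLLM code):

# Hypothetical illustration; DummyLayer and the loop are not vLLM code.
import torch
import torch.nn as nn

class DummyLayer(nn.Module):
    """Stands in for MistralDecoderLayer: consumes t_cond when provided."""

    def forward(self, positions, hidden_states, residual, t_cond=None):
        if residual is None:
            residual = hidden_states
        if t_cond is not None:
            hidden_states = hidden_states * (1 + t_cond)
        return hidden_states, residual

layers = nn.ModuleList([DummyLayer() for _ in range(4)])
positions = torch.arange(6)
hidden_states = torch.randn(6, 32)
t_cond = torch.zeros(6, 32)  # zero conditioning leaves the scale at 1

residual = None
for layer in layers:
    hidden_states, residual = layer(positions, hidden_states, residual, t_cond=t_cond)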

class MistralForCausalLM(LlamaForCausalLM):
    # Mistral: We don't support LoRA on the embedding layers.