Add Mistral Large 3 and Ministral 3 (#29757)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Mickael Seznec <mickael@mistral.ai>
This commit is contained in:
Julien Denize
2025-12-02 11:29:00 +01:00
committed by GitHub
parent 8bbcf8b6e7
commit d8c6210eea
16 changed files with 724 additions and 30 deletions

View File

@@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
def _get_llama_4_scaling(
original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
scaling = 1 + scaling_beta * torch.log(
1 + torch.floor(positions / original_max_position_embeddings)
)
# Broadcast over num_heads and head_dim
return scaling[..., None, None]
class DeepseekV2Attention(nn.Module):
def __init__(
self,
@@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module):
prefix=f"{prefix}.o_proj",
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = "deepseek_yarn"
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
self.rotary_emb = get_rope(
qk_rope_head_dim,
@@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module):
is_neox_style=False,
)
if config.rope_parameters["rope_type"] != "default":
if (
config.rope_parameters["rope_type"] != "default"
and config.rope_parameters["rope_type"] == "deepseek_yarn"
):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
@@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module):
k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
# Apply llama 4 scaling if provided
if llama_4_scaling is not None:
q *= llama_4_scaling
# padding value to qk_head_dim for alignment
v = torch.nn.functional.pad(
v, [0, self.qk_head_dim - self.v_head_dim], value=0
@@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module):
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = "deepseek_yarn"
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module):
rope_parameters=config.rope_parameters,
is_neox_style=False,
)
if config.rope_parameters["rope_type"] != "default":
if (
config.rope_parameters["rope_type"] != "default"
and config.rope_parameters["rope_type"] == "deepseek_yarn"
):
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
scaling_factor = config.rope_parameters["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
) -> torch.Tensor:
return self.mla_attn(positions, hidden_states)
return self.mla_attn(positions, hidden_states, llama_4_scaling)
class DeepseekV2DecoderLayer(nn.Module):
@@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Self Attention
if residual is None:
@@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=llama_4_scaling,
)
if (
@@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
llama_4_scaling: torch.Tensor | None
if llama_4_scaling_config is not None:
llama_4_scaling = _get_llama_4_scaling(
original_max_position_embeddings=llama_4_scaling_config[
"original_max_position_embeddings"
],
scaling_beta=llama_4_scaling_config["beta"],
positions=positions,
)
else:
llama_4_scaling = None
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
@@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
}
model_cls = DeepseekV2Model
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM(
"kv_a_proj_with_mqa",
]
self.model = DeepseekV2Model(
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:

View File

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import regex as re
import torch
from vllm.model_executor.models.deepseek_v2 import DeepseekV3ForCausalLM
class MistralLarge3ForCausalLM(DeepseekV3ForCausalLM):
# fmt: off
remapping = {
r"layers\.(\d+)\.attention_norm\.weight": r"model.layers.\1.input_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.attention\.wq_a\.(\w+)": r"model.layers.\1.self_attn.q_a_proj.\2", # noqa: E501
r"layers\.(\d+)\.attention\.q_a_norm\.weight": r"model.layers.\1.self_attn.q_a_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.attention\.wq_b\.(\w+)": r"model.layers.\1.self_attn.q_b_proj.\2", # noqa: E501
r"layers\.(\d+)\.attention\.wkv_a_with_mqa\.(\w+)": r"model.layers.\1.self_attn.kv_a_proj_with_mqa.\2", # noqa: E501
r"layers\.(\d+)\.attention\.kv_a_norm\.weight": r"model.layers.\1.self_attn.kv_a_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.attention\.wkv_b\.(\w+)": r"model.layers.\1.self_attn.kv_b_proj.\2", # noqa: E501
r"layers\.(\d+)\.attention\.wo\.(\w+)": r"model.layers.\1.self_attn.o_proj.\2", # noqa: E501
r"layers\.(\d+)\.ffn_norm\.weight": r"model.layers.\1.post_attention_layernorm.weight", # noqa: E501
r"layers\.(\d+)\.feed_forward\.w1\.(\w+)": r"model.layers.\1.mlp.gate_proj.\2", # noqa: E501
r"layers\.(\d+)\.feed_forward\.w2\.(\w+)": r"model.layers.\1.mlp.down_proj.\2", # noqa: E501
r"layers\.(\d+)\.feed_forward\.w3\.(\w+)": r"model.layers.\1.mlp.up_proj.\2", # noqa: E501
r"layers\.(\d+)\.gate\.weight": r"model.layers.\1.mlp.gate.weight", # noqa: E501
r"layers\.(\d+)\.shared_experts\.w1\.(\w+)": r"model.layers.\1.mlp.shared_experts.gate_proj.\2", # noqa: E501
r"layers\.(\d+)\.shared_experts\.w2\.(\w+)": r"model.layers.\1.mlp.shared_experts.down_proj.\2", # noqa: E501
r"layers\.(\d+)\.shared_experts\.w3\.(\w+)": r"model.layers.\1.mlp.shared_experts.up_proj.\2", # noqa: E501
r"layers\.(\d+)\.experts\.(\d+)\.w1\.(\w+)": r"model.layers.\1.mlp.experts.\2.gate_proj.\3", # noqa: E501
r"layers\.(\d+)\.experts\.(\d+)\.w2\.(\w+)": r"model.layers.\1.mlp.experts.\2.down_proj.\3", # noqa: E501
r"layers\.(\d+)\.experts\.(\d+)\.w3\.(\w+)": r"model.layers.\1.mlp.experts.\2.up_proj.\3", # noqa: E501
r"norm\.weight": "model.norm.weight", # noqa: E501
r"tok_embeddings\.weight": "model.embed_tokens.weight", # noqa: E501
r"output\.weight": "lm_head.weight", # noqa: E501
}
# fmt: on
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
return super().load_weights(map(self._remap_mistral_to_ds, weights))
def _remap_mistral_to_ds(
self, weight: tuple[str, torch.Tensor]
) -> tuple[str, torch.Tensor]:
"""Remap Mistral parameters to DeepseekV2 parameters."""
name, loaded_weight = weight
for k, v in self.remapping.items():
match = re.fullmatch(k, name)
if match:
name = re.sub(k, v, name)
break
else:
raise ValueError(f"Cannot remap {name}")
# Remapping scale names. We could do this in the regex above but it
# would triple the number of lines for most layers.
if name.endswith(".qscale_act"):
name = re.sub(r"\.qscale_act$", ".input_scale", name)
elif name.endswith(".qscale_weight"):
name = re.sub(r"\.qscale_weight$", ".weight_scale", name)
return name, loaded_weight

View File

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import partial
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2DecoderLayer,
DeepseekV2Model,
)
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.mistral_large_3 import MistralLarge3ForCausalLM
from vllm.multimodal.inputs import NestedTensors
from .utils import (
_merge_multimodal_embeddings,
make_empty_intermediate_tensors_factory,
maybe_prefix,
)
logger = init_logger(__name__)
@support_torch_compile
class EagleMistralLarge3Model(DeepseekV2Model):
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", start_layer_id: int = 0
):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vllm_config = vllm_config
self.vocab_size = config.vocab_size
assert get_pp_group().world_size == 1
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
self.layers = nn.ModuleList(
[
DeepseekV2DecoderLayer(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
)
for i in range(self.config.num_hidden_layers)
]
)
self.start_layer = 0
self.end_layer = self.config.num_hidden_layers
self.fc = RowParallelLinear(
self.config.hidden_size * 2,
self.config.hidden_size,
bias=False,
input_is_parallel=False,
quant_config=quant_config,
return_bias=False,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_input_ids(input_ids)
inputs_embeds = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1))
output = super().forward(
input_ids, positions, intermediate_tensors=None, inputs_embeds=inputs_embeds
)
assert isinstance(output, torch.Tensor)
return output
class EagleMistralLarge3ForCausalLM(MistralLarge3ForCausalLM):
remapping = MistralLarge3ForCausalLM.remapping | {
r"eagle_linear\.weight": r"model.fc.weight",
r"eagle_linear\.qscale_act": r"model.fc.input_scale",
r"eagle_linear\.qscale_weight": r"model.fc.weight_scale",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
)
vllm_config.model_config = vllm_config.speculative_config.draft_model_config
# draft model quantization config may differ from target model
self.quant_config = VllmConfig.get_quantization_config(
vllm_config.speculative_config.draft_model_config, vllm_config.load_config
)
vllm_config.quant_config = self.quant_config
self.model_cls = partial(
EagleMistralLarge3Model, start_layer_id=target_layer_num
)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = super().embed_input_ids(input_ids)
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
assert is_multimodal is not None
return _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.model(input_ids, positions, hidden_states, inputs_embeds)
return hidden_states, hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# Pretend we've loaded the embedding and lm_head weights
# (later copied from target model)
return super().load_weights(weights) | {
"model.embed_tokens.weight",
"lm_head.weight",
}
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: NestedTensors | None = None,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)

View File

@@ -145,6 +145,7 @@ _TEXT_GENERATION_MODELS = {
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
@@ -424,6 +425,10 @@ _SPECULATIVE_DECODING_MODELS = {
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleMistralLarge3ForCausalLM": (
"mistral_large_3_eagle",
"EagleMistralLarge3ForCausalLM",
),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),