Add Mistral Large 3 and Ministral 3 (#29757)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Mickael Seznec <mickael@mistral.ai>
commit d8c6210eea (parent 8bbcf8b6e7)
Author: Julien Denize
Date: 2025-12-02 11:29:00 +01:00
Committed by: GitHub
16 changed files with 724 additions and 30 deletions

vllm/model_executor/models/deepseek_v2.py

@@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
+def _get_llama_4_scaling(
+    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
+) -> torch.Tensor:
+    scaling = 1 + scaling_beta * torch.log(
+        1 + torch.floor(positions / original_max_position_embeddings)
+    )
+    # Broadcast over num_heads and head_dim
+    return scaling[..., None, None]
+
+
 class DeepseekV2Attention(nn.Module):
     def __init__(
         self,
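For orientation, here is a minimal standalone sketch of what the new `_get_llama_4_scaling` helper computes. The function body mirrors the diff; the parameter values (an 8192-token original context and beta of 0.1) are illustrative assumptions, not values from the commit.

```python
import torch

def _get_llama_4_scaling(
    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
    # 1 + beta * log(1 + floor(pos / original_ctx)): exactly 1 inside the
    # original context window, then growing logarithmically with each
    # additional window of positions.
    scaling = 1 + scaling_beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Broadcast over num_heads and head_dim
    return scaling[..., None, None]

positions = torch.tensor([0, 8191, 8192, 65536])
print(_get_llama_4_scaling(8192, 0.1, positions).squeeze())
# tensor([1.0000, 1.0000, 1.0693, 1.2197])
```

Positions below the original window keep a factor of exactly 1, so short-context behavior is unchanged.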
@@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module):
             prefix=f"{prefix}.o_proj",
         )
 
         if config.rope_parameters["rope_type"] != "default":
-            config.rope_parameters["rope_type"] = "deepseek_yarn"
+            config.rope_parameters["rope_type"] = (
+                "deepseek_yarn"
+                if config.rope_parameters.get("apply_yarn_scaling", True)
+                else "deepseek_llama_scaling"
+            )
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module):
             is_neox_style=False,
         )
-        if config.rope_parameters["rope_type"] != "default":
+        if (
+            config.rope_parameters["rope_type"] != "default"
+            and config.rope_parameters["rope_type"] == "deepseek_yarn"
+        ):
             mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
             scaling_factor = config.rope_parameters["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
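The two hunks above rewrite `rope_type` based on a new `apply_yarn_scaling` flag and then gate YaRN's mscale correction on the result. A hedged sketch of the dispatch, using a hypothetical rope config dict with only the keys the diff reads:

```python
# Hypothetical config values, for illustration only.
rope_parameters = {"rope_type": "yarn", "factor": 40.0, "apply_yarn_scaling": False}

if rope_parameters["rope_type"] != "default":
    rope_parameters["rope_type"] = (
        "deepseek_yarn"
        if rope_parameters.get("apply_yarn_scaling", True)
        else "deepseek_llama_scaling"
    )

print(rope_parameters["rope_type"])  # deepseek_llama_scaling
# Since the final rope_type is not "deepseek_yarn", the mscale branch is
# skipped and the attention softmax scale is left at its default.
```

Note that the second clause of the new `if` makes the `!= "default"` check redundant; the commit keeps the original guard explicit rather than simplifying it.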
@@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        llama_4_scaling: torch.Tensor | None,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
@@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module):
         k = torch.empty_like(q)
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe
+
+        # Apply llama 4 scaling if provided
+        if llama_4_scaling is not None:
+            q *= llama_4_scaling
+
         # padding value to qk_head_dim for alignment
         v = torch.nn.functional.pad(
             v, [0, self.qk_head_dim - self.v_head_dim], value=0
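A small shape sketch of the `q` scaling applied in the hunk above; the dimensions are illustrative assumptions (a DeepSeek-style qk_head_dim of 192), not values read from the commit.

```python
import torch

num_tokens, num_heads, qk_head_dim = 4, 16, 192
q = torch.randn(num_tokens, num_heads, qk_head_dim)

# _get_llama_4_scaling returns shape [num_tokens, 1, 1], so the in-place
# multiply broadcasts one per-token factor across every head and channel.
llama_4_scaling = torch.tensor([1.0, 1.0, 1.0693, 1.2197])[..., None, None]
q *= llama_4_scaling
assert q.shape == (num_tokens, num_heads, qk_head_dim)
```

Scaling only `q` (and not `k`) is enough to scale each attention logit exactly once, since the logits are linear in `q`.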
@@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module):
         )
 
         if config.rope_parameters["rope_type"] != "default":
-            config.rope_parameters["rope_type"] = "deepseek_yarn"
+            config.rope_parameters["rope_type"] = (
+                "deepseek_yarn"
+                if config.rope_parameters.get("apply_yarn_scaling", True)
+                else "deepseek_llama_scaling"
+            )
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module):
             rope_parameters=config.rope_parameters,
             is_neox_style=False,
         )
-        if config.rope_parameters["rope_type"] != "default":
+        if (
+            config.rope_parameters["rope_type"] != "default"
+            and config.rope_parameters["rope_type"] == "deepseek_yarn"
+        ):
             mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
             scaling_factor = config.rope_parameters["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        llama_4_scaling: torch.Tensor | None,
     ) -> torch.Tensor:
-        return self.mla_attn(positions, hidden_states)
+        return self.mla_attn(positions, hidden_states, llama_4_scaling)
 
 
 class DeepseekV2DecoderLayer(nn.Module):
@@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: torch.Tensor | None,
+        llama_4_scaling: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # Self Attention
         if residual is None:
@@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            llama_4_scaling=llama_4_scaling,
         )
 
         if (
@@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
+        # Compute llama 4 scaling once per forward pass if enabled
+        llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
+        llama_4_scaling: torch.Tensor | None
+        if llama_4_scaling_config is not None:
+            llama_4_scaling = _get_llama_4_scaling(
+                original_max_position_embeddings=llama_4_scaling_config[
+                    "original_max_position_embeddings"
+                ],
+                scaling_beta=llama_4_scaling_config["beta"],
+                positions=positions,
+            )
+        else:
+            llama_4_scaling = None
+
         for layer in islice(self.layers, self.start_layer, self.end_layer):
-            hidden_states, residual = layer(positions, hidden_states, residual)
+            hidden_states, residual = layer(
+                positions, hidden_states, residual, llama_4_scaling
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
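The scaling tensor is built once per forward pass in `DeepseekV2Model.forward` and threaded through every decoder layer, keeping the per-layer cost to a single broadcast multiply. A sketch of the config gating, with a made-up config object whose key names match those read in the diff:

```python
import torch

def _get_llama_4_scaling(original_max, beta, positions):
    # Same formula as the helper added at the top of the file.
    return (1 + beta * torch.log(1 + torch.floor(positions / original_max)))[
        ..., None, None
    ]

class _DemoConfig:
    # Hypothetical HF-config entry; models without it get no scaling.
    llama_4_scaling = {"original_max_position_embeddings": 8192, "beta": 0.1}

config = _DemoConfig()
llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
if llama_4_scaling_config is not None:
    llama_4_scaling = _get_llama_4_scaling(
        llama_4_scaling_config["original_max_position_embeddings"],
        llama_4_scaling_config["beta"],
        torch.arange(10_000),
    )
else:
    llama_4_scaling = None  # existing checkpoints keep the old behavior
```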
@@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
     packed_modules_mapping = {
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
+    model_cls = DeepseekV2Model
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM(
                 "kv_a_proj_with_mqa",
             ]
 
-        self.model = DeepseekV2Model(
+        self.model = self.model_cls(
             vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
         )
         if get_pp_group().is_last_rank:
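The new `model_cls` hook turns the previously hard-coded `DeepseekV2Model` into an overridable class attribute, presumably so the Mistral Large 3 and Ministral 3 models added by this commit can reuse the DeepseekV2 stack. The subclass names below are hypothetical, for illustration only:

```python
# Assumes these names are importable from the module shown in this diff.
from vllm.model_executor.models.deepseek_v2 import (
    DeepseekV2ForCausalLM,
    DeepseekV2Model,
)

class CustomDeepseekModel(DeepseekV2Model):  # hypothetical backbone variant
    pass

class CustomForCausalLM(DeepseekV2ForCausalLM):
    # DeepseekV2ForCausalLM.__init__ now instantiates self.model_cls, so a
    # subclass only has to point the attribute at its backbone; no copy of
    # the parent __init__ is needed.
    model_cls = CustomDeepseekModel
```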