Add Mistral Large 3 and Ministral 3 (#29757)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Mickael Seznec <mickael@mistral.ai>
This commit is contained in:
@@ -395,6 +395,16 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
return 0.1 * mscale * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def _get_llama_4_scaling(
|
||||
original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
scaling = 1 + scaling_beta * torch.log(
|
||||
1 + torch.floor(positions / original_max_position_embeddings)
|
||||
)
|
||||
# Broadcast over num_heads and head_dim
|
||||
return scaling[..., None, None]
|
||||
|
||||
|
||||
class DeepseekV2Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -481,7 +491,11 @@ class DeepseekV2Attention(nn.Module):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
if config.rope_parameters["rope_type"] != "default":
|
||||
config.rope_parameters["rope_type"] = "deepseek_yarn"
|
||||
config.rope_parameters["rope_type"] = (
|
||||
"deepseek_yarn"
|
||||
if config.rope_parameters.get("apply_yarn_scaling", True)
|
||||
else "deepseek_llama_scaling"
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
@@ -491,7 +505,10 @@ class DeepseekV2Attention(nn.Module):
|
||||
is_neox_style=False,
|
||||
)
|
||||
|
||||
if config.rope_parameters["rope_type"] != "default":
|
||||
if (
|
||||
config.rope_parameters["rope_type"] != "default"
|
||||
and config.rope_parameters["rope_type"] == "deepseek_yarn"
|
||||
):
|
||||
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
|
||||
scaling_factor = config.rope_parameters["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
@@ -511,6 +528,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
if self.q_lora_rank is not None:
|
||||
q = self.q_a_proj(hidden_states)[0]
|
||||
@@ -536,6 +554,11 @@ class DeepseekV2Attention(nn.Module):
|
||||
k = torch.empty_like(q)
|
||||
k[..., : self.qk_nope_head_dim] = k_nope
|
||||
k[..., self.qk_nope_head_dim :] = k_pe
|
||||
|
||||
# Apply llama 4 scaling if provided
|
||||
if llama_4_scaling is not None:
|
||||
q *= llama_4_scaling
|
||||
|
||||
# padding value to qk_head_dim for alignment
|
||||
v = torch.nn.functional.pad(
|
||||
v, [0, self.qk_head_dim - self.v_head_dim], value=0
|
||||
@@ -987,7 +1010,12 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
)
|
||||
|
||||
if config.rope_parameters["rope_type"] != "default":
|
||||
config.rope_parameters["rope_type"] = "deepseek_yarn"
|
||||
config.rope_parameters["rope_type"] = (
|
||||
"deepseek_yarn"
|
||||
if config.rope_parameters.get("apply_yarn_scaling", True)
|
||||
else "deepseek_llama_scaling"
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
@@ -995,7 +1023,11 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=False,
|
||||
)
|
||||
if config.rope_parameters["rope_type"] != "default":
|
||||
|
||||
if (
|
||||
config.rope_parameters["rope_type"] != "default"
|
||||
and config.rope_parameters["rope_type"] == "deepseek_yarn"
|
||||
):
|
||||
mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
|
||||
scaling_factor = config.rope_parameters["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
@@ -1064,8 +1096,9 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
llama_4_scaling: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
return self.mla_attn(positions, hidden_states)
|
||||
return self.mla_attn(positions, hidden_states, llama_4_scaling)
|
||||
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
@@ -1155,6 +1188,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
llama_4_scaling: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
@@ -1165,6 +1199,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
llama_4_scaling=llama_4_scaling,
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -1266,8 +1301,24 @@ class DeepseekV2Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
# Compute llama 4 scaling once per forward pass if enabled
|
||||
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
|
||||
llama_4_scaling: torch.Tensor | None
|
||||
if llama_4_scaling_config is not None:
|
||||
llama_4_scaling = _get_llama_4_scaling(
|
||||
original_max_position_embeddings=llama_4_scaling_config[
|
||||
"original_max_position_embeddings"
|
||||
],
|
||||
scaling_beta=llama_4_scaling_config["beta"],
|
||||
positions=positions,
|
||||
)
|
||||
else:
|
||||
llama_4_scaling = None
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, residual, llama_4_scaling
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
@@ -1325,6 +1376,7 @@ class DeepseekV2ForCausalLM(
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
model_cls = DeepseekV2Model
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@@ -1355,7 +1407,7 @@ class DeepseekV2ForCausalLM(
|
||||
"kv_a_proj_with_mqa",
|
||||
]
|
||||
|
||||
self.model = DeepseekV2Model(
|
||||
self.model = self.model_cls(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
|
||||
Reference in New Issue
Block a user