[Llama.py -> mistral.py] Extract mistral-only relevant code into separate file (#32780)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Patrick von Platen
2026-01-22 06:14:57 +01:00
committed by GitHub
parent 889722f3bf
commit 1579c9b5fd
3 changed files with 248 additions and 115 deletions

View File

@@ -152,24 +152,14 @@ class LlamaAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = self.hidden_size // self.total_num_heads
self.head_dim = head_dim
self.head_dim = head_dim or self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
self.do_llama_4_scaling = llama_4_scaling_config is not None
if self.do_llama_4_scaling:
self.llama_4_scaling_original_max_position_embeddings = (
llama_4_scaling_config["original_max_position_embeddings"]
)
self.llama_4_scaling_beta = llama_4_scaling_config["beta"]
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
@@ -229,17 +219,6 @@ class LlamaAttention(nn.Module):
prefix=f"{prefix}.attn",
)
def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
    """Compute the Llama-4 style per-position attention temperature.

    Implements scale(p) = 1 + beta * log(1 + floor(p / original_max)),
    where ``original_max`` is ``llama_4_scaling_original_max_position_embeddings``.
    The scale grows logarithmically each time a position passes another
    multiple of the original training context length.

    Args:
        positions: Tensor of token position indices (presumably a flat
            1-D tensor of per-token positions — confirm against caller).

    Returns:
        Tensor of scale factors with a trailing singleton dimension so it
        broadcasts over the head_dim axis of ``q``.
    """
    # Llama4 scaling
    scaling = 1 + self.llama_4_scaling_beta * torch.log(
        1
        + torch.floor(
            positions / self.llama_4_scaling_original_max_position_embeddings
        )
    )
    # Broadcast over head_dim
    return scaling.unsqueeze(-1)
def forward(
self,
positions: torch.Tensor,
@@ -248,9 +227,6 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
if self.do_llama_4_scaling:
attn_scale = self._get_llama_4_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
@@ -279,6 +255,7 @@ class LlamaDecoderLayer(nn.Module):
vllm_config: VllmConfig,
prefix: str = "",
config: LlamaConfig | None = None,
attn_layer_type: type[nn.Module] = LlamaAttention,
) -> None:
super().__init__()
@@ -307,7 +284,7 @@ class LlamaDecoderLayer(nn.Module):
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = LlamaAttention(
self.self_attn = attn_layer_type(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -537,32 +514,6 @@ class LlamaForCausalLM(
"lm_head": "output_embeddings",
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"q_fake_quantizer.qscale_act": "attn.q_scale",
"k_fake_quantizer.qscale_act": "k_scale",
"v_fake_quantizer.qscale_act": "v_scale",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm",
}
def __init__(
self,
*,
@@ -649,67 +600,7 @@ class LlamaForCausalLM(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(
self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights
)
def maybe_remap_mistral(
    self,
    name: str,
    loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
    """Remap a weight from the Mistral ``consolidated.safetensors`` naming
    scheme (also used by Mistral and Llama <= 2) to this model's names.

    Args:
        name: Parameter name as it appears in the Mistral-format checkpoint.
        loaded_weight: Tensor stored under ``name``.

    Returns:
        The (possibly rewritten) parameter name and the (possibly
        row-permuted) weight tensor.
    """

    def permute(w: torch.Tensor, n_heads: int, attn_out: int):
        # Reorder rows per head: view as (heads, pairs, 2, out), swap the
        # pair/2 axes, flatten back.  NOTE(review): this looks like the
        # standard interleaved->half-split rotary permutation used for
        # HF Llama checkpoints — confirm against the rotary embedding impl.
        attn_in = self.config.head_dim * n_heads
        return (
            w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
            .transpose(1, 2)
            .reshape(attn_in, attn_out)
        )

    mapping = self.mistral_mapping
    modules = name.split(".")
    # rotary embeds should be sliced
    # If using quantized model in mistral format,
    # quantization scales (qscale_weight) also need to be sliced
    if "wk" in modules and modules[-1] == "weight":
        loaded_weight = permute(
            loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
        )
    elif (
        "wk" in modules
        and modules[-1] == "qscale_weight"
        and loaded_weight.numel() > 1
    ):
        # Multi-element (per-channel) scales follow the same row permutation
        # as the weight itself; scalar scales are left untouched.
        loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
    elif "wq" in modules and modules[-1] == "weight":
        loaded_weight = permute(
            loaded_weight, self.config.num_attention_heads, self.config.hidden_size
        )
    elif (
        "wq" in modules
        and modules[-1] == "qscale_weight"
        and loaded_weight.numel() > 1
    ):
        loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)
    # Translate each dotted component via mistral_mapping.  Two-token keys
    # (e.g. "kv_fake_quantizer.qscale_act") take precedence over single
    # tokens.  str.replace rewrites every occurrence in the name; the
    # "mapping[item] not in name" guard skips tokens whose target text is
    # already present, preventing double rewrites.
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None
        combined_item = f"{item}.{next_item}" if next_item is not None else None
        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name, loaded_weight
return loader.load_weights(weights)
class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):

View File

@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mistral adaptation of the LLaMA architecture."""
from collections.abc import Iterable
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
)
from vllm.v1.attention.backend import AttentionType
from .utils import AutoWeightsLoader
class MistralAttention(LlamaAttention):
    """Llama attention with optional Llama-4 style positional query scaling.

    The scaling is enabled when the config carries a ``llama_4_scaling``
    dict with ``original_max_position_embeddings`` and ``beta`` entries;
    otherwise this behaves exactly like :class:`LlamaAttention`.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=prefix,
            attn_type=attn_type,
        )
        # Llama-4 scaling is opt-in via the model config.
        scaling_cfg: dict[str, int | float | str] | None = getattr(
            config, "llama_4_scaling", None
        )
        if scaling_cfg is None:
            self.do_llama_4_scaling = False
        else:
            self.do_llama_4_scaling = True
            self.llama_4_scaling_original_max_position_embeddings = scaling_cfg[
                "original_max_position_embeddings"
            ]
            self.llama_4_scaling_beta = scaling_cfg["beta"]

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        """Per-position scale: 1 + beta * log(1 + floor(p / original_max))."""
        # How many multiples of the original training context each position
        # has passed.
        window_index = torch.floor(
            positions / self.llama_4_scaling_original_max_position_embeddings
        )
        scale = 1 + self.llama_4_scaling_beta * torch.log(1 + window_index)
        # Trailing singleton axis so the scale broadcasts over head_dim.
        return scale.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project, apply rotary embeddings (plus optional Llama-4 query
        scaling), attend, and project the output back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            # Cast back so the scaled query keeps its original dtype.
            q = (q * self._get_llama_4_attn_scale(positions)).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class MistralDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer specialized to use :class:`MistralAttention`.

    Additionally records the numeric layer index and, when the quantization
    method opts in and the KV cache is fp8, points each layernorm at the
    projection whose quantization scaling it should source (consumed
    elsewhere in vLLM).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            config=config,
            attn_layer_type=MistralAttention,
        )
        # prefix is expected to end with the layer index,
        # e.g. "model.layers.7" -> 7.
        self.layer_idx = int(prefix.split(sep=".")[-1])
        quant_config = self.get_quant_config(vllm_config)
        # NOTE: dropped a dead `config = config or vllm_config.model_config
        # .hf_config` reassignment — the local was never read afterwards.
        # Fuse quantization scaling into the layernorms only when the quant
        # method enables it AND the KV cache dtype is an fp8 variant.
        do_fusion = getattr(
            quant_config, "enable_quantization_scaling_fusion", False
        ) and vllm_config.cache_config.cache_dtype.startswith("fp8")
        if do_fusion:
            self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
            self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj
@support_torch_compile
class MistralModel(LlamaModel):
    """LlamaModel whose default decoder layer is :class:`MistralDecoderLayer`.

    Everything else — embeddings, layer stacking, final norm — is inherited
    unchanged from :class:`LlamaModel`.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=layer_type,
        )
class MistralForCausalLM(LlamaForCausalLM):
    """Mistral causal LM: LlamaForCausalLM wired to Mistral layers, plus
    support for loading Mistral-format (``consolidated.safetensors``)
    checkpoints via name remapping."""

    # Mistral: We don't support LoRA on the embedding layers.
    embedding_modules: dict[str, str] = {}

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        # Override the backbone factory so the model is a MistralModel.
        return MistralModel(
            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, remapping Mistral-format names on the fly.

        Returns the set of parameter names that were loaded.
        """
        loader = AutoWeightsLoader(
            self,
            # When embeddings are tied, the lm_head weight is not a separate
            # parameter and must be skipped.
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        # Generator: remap each (name, tensor) pair lazily while streaming.
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        """Remap one weight from the Mistral checkpoint naming scheme (also
        used by Llama <= 2) to this model's parameter names, permuting q/k
        rotary weights where needed.  Returns the new (name, tensor) pair."""

        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            # Reorder rows per head: view as (heads, pairs, 2, out), swap the
            # pair/2 axes, flatten back.  NOTE(review): appears to be the
            # standard interleaved->half-split rotary permutation for HF
            # Llama-style checkpoints — confirm against the rotary impl.
            attn_in = self.config.head_dim * n_heads
            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")
        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            # Only multi-element (per-channel) scales are permuted; scalar
            # scales are layout-independent.
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)
        # Rewrite dotted components via mistral_mapping.  Two-token keys
        # (e.g. "kv_fake_quantizer.qscale_act") win over single tokens;
        # str.replace rewrites every occurrence, and the
        # "mapping[item] not in name" guard avoids rewriting names that
        # already contain the target text.
        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None
            combined_item = f"{item}.{next_item}" if next_item is not None else None
            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])
        return name, loaded_weight

View File

@@ -153,7 +153,7 @@ _TEXT_GENERATION_MODELS = {
"MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
"MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("mistral", "MistralForCausalLM"),
"MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
# transformers's mpt class has lower case