[Llama.py -> mistral.py] Extract mistral-only relevant code into separate file (#32780)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
committed by GitHub
parent 889722f3bf
commit 1579c9b5fd
vllm/model_executor/models/llama.py
@@ -152,24 +152,14 @@ class LlamaAttention(nn.Module):
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        head_dim = getattr(config, "head_dim", None)
-        if head_dim is None:
-            head_dim = self.hidden_size // self.total_num_heads
-        self.head_dim = head_dim
+        self.head_dim = head_dim or self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
-
-        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
-        self.do_llama_4_scaling = llama_4_scaling_config is not None
-        if self.do_llama_4_scaling:
-            self.llama_4_scaling_original_max_position_embeddings = (
-                llama_4_scaling_config["original_max_position_embeddings"]
-            )
-            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
@@ -229,17 +219,6 @@ class LlamaAttention(nn.Module):
            prefix=f"{prefix}.attn",
        )

-    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
-        # Llama4 scaling
-        scaling = 1 + self.llama_4_scaling_beta * torch.log(
-            1
-            + torch.floor(
-                positions / self.llama_4_scaling_original_max_position_embeddings
-            )
-        )
-        # Broadcast over head_dim
-        return scaling.unsqueeze(-1)
-
    def forward(
        self,
        positions: torch.Tensor,
@@ -248,9 +227,6 @@ class LlamaAttention(nn.Module):
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
-        if self.do_llama_4_scaling:
-            attn_scale = self._get_llama_4_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
@@ -279,6 +255,7 @@ class LlamaDecoderLayer(nn.Module):
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
+        attn_layer_type: type[nn.Module] = LlamaAttention,
    ) -> None:
        super().__init__()

@@ -307,7 +284,7 @@ class LlamaDecoderLayer(nn.Module):
        else:
            attn_type = AttentionType.ENCODER_ONLY

-        self.self_attn = LlamaAttention(
+        self.self_attn = attn_layer_type(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@@ -537,32 +514,6 @@ class LlamaForCausalLM(
        "lm_head": "output_embeddings",
    }

-    # Mistral/Llama models can also be loaded with --load-format mistral
-    # from consolidated.safetensors checkpoints
-    mistral_mapping = {
-        "layers": "model.layers",
-        "attention": "self_attn",
-        "qscale_act": "input_scale",
-        "qscale_weight": "weight_scale",
-        "kv_fake_quantizer.qscale_act": "kv_scale",
-        "q_fake_quantizer.qscale_act": "attn.q_scale",
-        "k_fake_quantizer.qscale_act": "k_scale",
-        "v_fake_quantizer.qscale_act": "v_scale",
-        "wq": "q_proj",
-        "wk": "k_proj",
-        "wv": "v_proj",
-        "wo": "o_proj",
-        "attention_norm": "input_layernorm",
-        "feed_forward": "mlp",
-        "w1": "gate_proj",
-        "w2": "down_proj",
-        "w3": "up_proj",
-        "ffn_norm": "post_attention_layernorm",
-        "tok_embeddings": "model.embed_tokens",
-        "output": "lm_head",
-        "norm": "model.norm",
-    }
-
    def __init__(
        self,
        *,
@@ -649,67 +600,7 @@ class LlamaForCausalLM(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
-        return loader.load_weights(
-            self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights
-        )
-
-    # This function is used to remap the mistral format as
-    # used by Mistral and Llama <=2
-    def maybe_remap_mistral(
-        self,
-        name: str,
-        loaded_weight: torch.Tensor,
-    ) -> tuple[str, torch.Tensor]:
-        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
-            attn_in = self.config.head_dim * n_heads
-
-            return (
-                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
-                .transpose(1, 2)
-                .reshape(attn_in, attn_out)
-            )
-
-        mapping = self.mistral_mapping
-        modules = name.split(".")
-
-        # rotary embeds should be sliced
-        # If using quantized model in mistral format,
-        # quantization scales (qscale_weight) also need to be sliced
-        if "wk" in modules and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
-            )
-        elif (
-            "wk" in modules
-            and modules[-1] == "qscale_weight"
-            and loaded_weight.numel() > 1
-        ):
-            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
-        elif "wq" in modules and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
-            )
-        elif (
-            "wq" in modules
-            and modules[-1] == "qscale_weight"
-            and loaded_weight.numel() > 1
-        ):
-            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)
-
-        num_modules = len(modules)
-        for i in range(num_modules):
-            item = modules[i]
-            next_item = modules[i + 1] if i < num_modules - 1 else None
-
-            combined_item = f"{item}.{next_item}" if next_item is not None else None
-
-            if combined_item in mapping:
-                name = name.replace(combined_item, mapping[combined_item])
-            elif item in mapping and mapping[item] not in name:
-                name = name.replace(item, mapping[item])
-
-        return name, loaded_weight
+        return loader.load_weights(weights)

class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
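
The net effect of the llama.py hunks above: the Mistral-only pieces (the llama-4-style attention scaling and the consolidated-checkpoint remapping) are deleted, and LlamaDecoderLayer gains an attn_layer_type parameter so a subclass can inject its own attention implementation instead of the hard-coded LlamaAttention. A minimal sketch of that hook follows; MyAttention and MyDecoderLayer are made-up names, only the attn_layer_type keyword comes from this diff.

# Hypothetical sketch of the new extension hook.
from vllm.model_executor.models.llama import LlamaAttention, LlamaDecoderLayer


class MyAttention(LlamaAttention):
    pass  # override __init__/forward with model-specific behavior


class MyDecoderLayer(LlamaDecoderLayer):
    def __init__(self, vllm_config, prefix="", config=None):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            config=config,
            attn_layer_type=MyAttention,  # injected instead of hard-coded LlamaAttention
        )

This is exactly the pattern the new mistral.py below uses.
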
vllm/model_executor/models/mistral.py (new file, 242 lines)
@@ -0,0 +1,242 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Mistral adaptation of the LLaMA architecture."""

from collections.abc import Iterable

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
)
from vllm.v1.attention.backend import AttentionType

from .utils import AutoWeightsLoader


class MistralAttention(LlamaAttention):
    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position_embeddings: int = 8192,
        quant_config: QuantizationConfig | None = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__(
            config=config,
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=bias,
            bias_o_proj=bias_o_proj,
            cache_config=cache_config,
            prefix=prefix,
            attn_type=attn_type,
        )

        llama_4_scaling_config: dict[str, int | float | str] | None = getattr(
            config, "llama_4_scaling", None
        )
        self.do_llama_4_scaling = llama_4_scaling_config is not None
        if self.do_llama_4_scaling:
            assert llama_4_scaling_config is not None
            self.llama_4_scaling_original_max_position_embeddings = (
                llama_4_scaling_config["original_max_position_embeddings"]
            )
            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
        # Llama4 scaling
        scaling = 1 + self.llama_4_scaling_beta * torch.log(
            1
            + torch.floor(
                positions / self.llama_4_scaling_original_max_position_embeddings
            )
        )
        # Broadcast over head_dim
        return scaling.unsqueeze(-1)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        if self.do_llama_4_scaling:
            attn_scale = self._get_llama_4_attn_scale(positions)
            q = (q * attn_scale).to(q.dtype)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
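
For intuition about the scaling applied in forward above: the query scale grows with the natural log of how many full original context windows a position has passed, and is then broadcast over head_dim and multiplied into q before attention. A small numeric sketch with made-up beta and original_max_position_embeddings values (the real ones come from config.llama_4_scaling):

import torch

beta = 0.1            # hypothetical value of llama_4_scaling["beta"]
original_max = 8192   # hypothetical original_max_position_embeddings

positions = torch.tensor([0, 4096, 20000, 100000], dtype=torch.float32)
scale = 1 + beta * torch.log(1 + torch.floor(positions / original_max))
# -> tensor([1.0000, 1.0000, 1.1099, 1.2565])
# Positions inside the original window keep scale 1.0; position 20000 falls in
# the third window (floor(20000 / 8192) == 2), giving 1 + 0.1 * ln(3) ~= 1.11.
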

class MistralDecoderLayer(LlamaDecoderLayer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            config=config,
            attn_layer_type=MistralAttention,
        )

        self.layer_idx = int(prefix.split(sep=".")[-1])
        quant_config = self.get_quant_config(vllm_config)
        config = config or vllm_config.model_config.hf_config

        do_fusion = getattr(
            quant_config, "enable_quantization_scaling_fusion", False
        ) and vllm_config.cache_config.cache_dtype.startswith("fp8")
        if do_fusion:
            self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
            self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj


@support_torch_compile
class MistralModel(LlamaModel):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)


class MistralForCausalLM(LlamaForCausalLM):
    # Mistral: We don't support LoRA on the embedding layers.
    embedding_modules: dict[str, str] = {}

    # Mistral/Llama models can also be loaded with --load-format mistral
    # from consolidated.safetensors checkpoints
    mistral_mapping = {
        "layers": "model.layers",
        "attention": "self_attn",
        "qscale_act": "input_scale",
        "qscale_weight": "weight_scale",
        "kv_fake_quantizer.qscale_act": "kv_scale",
        "q_fake_quantizer.qscale_act": "attn.q_scale",
        "k_fake_quantizer.qscale_act": "k_scale",
        "v_fake_quantizer.qscale_act": "v_scale",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)

    def _init_model(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = MistralDecoderLayer,
    ):
        return MistralModel(
            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(
            self.maybe_remap_mistral(name, loaded_weight)
            for name, loaded_weight in weights
        )

    def maybe_remap_mistral(
        self,
        name: str,
        loaded_weight: torch.Tensor,
    ) -> tuple[str, torch.Tensor]:
        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
            attn_in = self.config.head_dim * n_heads

            return (
                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
                .transpose(1, 2)
                .reshape(attn_in, attn_out)
            )

        mapping = self.mistral_mapping
        modules = name.split(".")

        # rotary embeds should be sliced
        # If using quantized model in mistral format,
        # quantization scales (qscale_weight) also need to be sliced
        if "wk" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
            )
        elif (
            "wk" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
        elif "wq" in modules and modules[-1] == "weight":
            loaded_weight = permute(
                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
            )
        elif (
            "wq" in modules
            and modules[-1] == "qscale_weight"
            and loaded_weight.numel() > 1
        ):
            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

        num_modules = len(modules)
        for i in range(num_modules):
            item = modules[i]
            next_item = modules[i + 1] if i < num_modules - 1 else None

            combined_item = f"{item}.{next_item}" if next_item is not None else None

            if combined_item in mapping:
                name = name.replace(combined_item, mapping[combined_item])
            elif item in mapping and mapping[item] not in name:
                name = name.replace(item, mapping[item])

        return name, loaded_weight
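
To make the remapping concrete, here is a self-contained replay of the name-translation loop above on one consolidated.safetensors key, using a subset of mistral_mapping; it ignores the rotary permute step, which rewrites only the tensor, not the name.

# Illustrative only: walk one checkpoint key through the mapping loop.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
}

name = "layers.0.attention.wq.weight"
modules = name.split(".")
for i, item in enumerate(modules):
    next_item = modules[i + 1] if i < len(modules) - 1 else None
    combined_item = f"{item}.{next_item}" if next_item is not None else None
    if combined_item in mapping:
        name = name.replace(combined_item, mapping[combined_item])
    elif item in mapping and mapping[item] not in name:
        name = name.replace(item, mapping[item])

print(name)  # model.layers.0.self_attn.q_proj.weight
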
vllm/model_executor/models/registry.py
@@ -153,7 +153,7 @@ _TEXT_GENERATION_MODELS = {
    "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
    "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
-    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
    "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
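
With the registry entry repointed, Mistral checkpoints now resolve to MistralForCausalLM in mistral.py instead of LlamaForCausalLM. Loading a native consolidated.safetensors checkpoint still goes through the mistral load format mentioned in the comments above; an illustrative invocation, with model name and flags following vLLM's documented Mistral usage rather than anything introduced by this change:

from vllm import LLM

# Illustrative only: the mistral tokenizer/config/load formats feed
# consolidated.safetensors keys through maybe_remap_mistral above.
llm = LLM(
    model="mistralai/Mistral-Nemo-Instruct-2407",
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)
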