Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import Any, Optional, Union
|
||||
@@ -38,27 +39,38 @@ from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class LlamaMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -89,8 +101,9 @@ class LlamaMLP(nn.Module):
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
@@ -101,7 +114,6 @@ class LlamaMLP(nn.Module):
|
||||
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
@@ -141,8 +153,7 @@ class LlamaAttention(nn.Module):
|
||||
head_dim = self.hidden_size // self.total_num_heads
|
||||
self.head_dim = head_dim
|
||||
# Phi models introduced a partial_rotary_factor parameter in the config
|
||||
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
|
||||
1)
|
||||
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
@@ -167,33 +178,36 @@ class LlamaAttention(nn.Module):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self._init_rotary_emb(config,
|
||||
rope_scaling=rope_scaling,
|
||||
quant_config=quant_config)
|
||||
self._init_rotary_emb(
|
||||
config, rope_scaling=rope_scaling, quant_config=quant_config
|
||||
)
|
||||
|
||||
sliding_window = None
|
||||
if layer_types := getattr(config, "layer_types", None):
|
||||
# Fix for Eagle3 compatibility:
|
||||
# for draft models, subtract target layer count
|
||||
# to get draft-relative layer index starting from 0
|
||||
if hasattr(config, 'target_layer_count'):
|
||||
if hasattr(config, "target_layer_count"):
|
||||
# This is a draft model,
|
||||
# adjust layer_idx to be relative to draft layers
|
||||
effective_layer_idx = layer_idx - config.target_layer_count
|
||||
else:
|
||||
# This is a target model, use layer_idx directly
|
||||
effective_layer_idx = layer_idx
|
||||
assert effective_layer_idx < len(layer_types), \
|
||||
assert effective_layer_idx < len(layer_types), (
|
||||
f"effective_layer_idx: {effective_layer_idx} \
|
||||
is out of bounds for layer_types: {layer_types}"
|
||||
)
|
||||
|
||||
is_sliding = layer_types[
|
||||
effective_layer_idx] == "sliding_attention"
|
||||
is_sliding = layer_types[effective_layer_idx] == "sliding_attention"
|
||||
if is_sliding:
|
||||
sliding_window = config.sliding_window
|
||||
|
||||
attn_cls = (EncoderOnlyAttention
|
||||
if attn_type == AttentionType.ENCODER_ONLY else Attention)
|
||||
attn_cls = (
|
||||
EncoderOnlyAttention
|
||||
if attn_type == AttentionType.ENCODER_ONLY
|
||||
else Attention
|
||||
)
|
||||
|
||||
self.attn = attn_cls(
|
||||
self.num_heads,
|
||||
@@ -219,9 +233,12 @@ class LlamaAttention(nn.Module):
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def _init_rotary_emb(self, config: LlamaConfig,
|
||||
rope_scaling: Optional[dict[str, Any]],
|
||||
quant_config: Optional[QuantizationConfig]) -> None:
|
||||
def _init_rotary_emb(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
rope_scaling: Optional[dict[str, Any]],
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
) -> None:
|
||||
is_neox_style = True
|
||||
is_gguf = quant_config and quant_config.get_name() == "gguf"
|
||||
if is_gguf and config.model_type == "llama":
|
||||
@@ -239,11 +256,12 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
@@ -254,18 +272,20 @@ class LlamaDecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
if rope_scaling is not None and getattr(
|
||||
config, "original_max_position_embeddings", None):
|
||||
config, "original_max_position_embeddings", None
|
||||
):
|
||||
rope_scaling["original_max_position_embeddings"] = (
|
||||
config.original_max_position_embeddings)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
config.original_max_position_embeddings
|
||||
)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||
# Support internlm/internlm-7b with bias
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
config, "bias", False
|
||||
)
|
||||
bias_o_proj = attention_bias
|
||||
# support internlm/internlm3-8b with qkv_bias
|
||||
if hasattr(config, 'qkv_bias'):
|
||||
if hasattr(config, "qkv_bias"):
|
||||
attention_bias = config.qkv_bias
|
||||
|
||||
# By default, Llama uses causal attention as it is a decoder-only model.
|
||||
@@ -281,8 +301,9 @@ class LlamaDecoderLayer(nn.Module):
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
config.num_attention_heads),
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
@@ -301,10 +322,10 @@ class LlamaDecoderLayer(nn.Module):
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -317,31 +338,28 @@ class LlamaDecoderLayer(nn.Module):
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
def get_quant_config(
|
||||
self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
|
||||
def get_quant_config(self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
|
||||
"""Get quantization config for this layer. Override in subclasses."""
|
||||
return vllm_config.quant_config
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
@@ -350,12 +368,16 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
if get_pp_group().is_first_rank or (config.tie_word_embeddings
|
||||
and get_pp_group().is_last_rank):
|
||||
if get_pp_group().is_first_rank or (
|
||||
config.tie_word_embeddings and get_pp_group().is_last_rank
|
||||
):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
@@ -376,9 +398,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.aux_hidden_state_layers = tuple[int, ...]()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
@@ -389,8 +411,9 @@ class LlamaModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
|
||||
list[torch.Tensor]]]:
|
||||
) -> Union[
|
||||
torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]]
|
||||
]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
@@ -404,16 +427,16 @@ class LlamaModel(nn.Module):
|
||||
|
||||
aux_hidden_states = []
|
||||
for idx, layer in enumerate(
|
||||
islice(self.layers, self.start_layer, self.end_layer)):
|
||||
islice(self.layers, self.start_layer, self.end_layer)
|
||||
):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
@@ -421,8 +444,7 @@ class LlamaModel(nn.Module):
|
||||
return hidden_states, aux_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -436,19 +458,19 @@ class LlamaModel(nn.Module):
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = (
|
||||
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
@@ -481,8 +503,7 @@ class LlamaModel(nn.Module):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
@@ -491,13 +512,13 @@ class LlamaModel(nn.Module):
|
||||
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
# LoRA specific attributes
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings"
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
@@ -527,11 +548,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
"norm": "model.norm",
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer,
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
@@ -539,9 +562,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = self._init_model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
layer_type=layer_type)
|
||||
self.model = self._init_model(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
layer_type=layer_type,
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
@@ -555,24 +580,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config else
|
||||
lora_config.lora_vocab_padding_size),
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size
|
||||
),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(
|
||||
self.model.embed_tokens)
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logit_scale)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size, logit_scale
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
@@ -581,13 +607,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def _init_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer):
|
||||
return LlamaModel(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
layer_type=layer_type)
|
||||
def _init_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = LlamaDecoderLayer,
|
||||
):
|
||||
return LlamaModel(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
@@ -599,8 +625,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
model_output = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
@@ -610,16 +637,15 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(
|
||||
self.maybe_remap_mistral(name, loaded_weight)
|
||||
for name, loaded_weight in weights)
|
||||
for name, loaded_weight in weights
|
||||
)
|
||||
|
||||
# This function is used to remap the mistral format as
|
||||
# used by Mistral and Llama <=2
|
||||
@@ -628,12 +654,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> tuple[str, torch.Tensor]:
|
||||
|
||||
def permute(w: torch.Tensor, n_heads: int, attn_out: int):
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
|
||||
return w.view(n_heads, attn_in // n_heads // 2, 2,
|
||||
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
|
||||
return (
|
||||
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
|
||||
.transpose(1, 2)
|
||||
.reshape(attn_in, attn_out)
|
||||
)
|
||||
|
||||
mapping = self.mistral_mapping
|
||||
modules = name.split(".")
|
||||
@@ -642,29 +670,32 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
# If using quantized model in mistral format,
|
||||
# quantization scales (qscale_weight) also need to be sliced
|
||||
if "wk" in modules and modules[-1] == "weight":
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_key_value_heads,
|
||||
self.config.hidden_size)
|
||||
elif "wk" in modules and modules[
|
||||
-1] == "qscale_weight" and loaded_weight.numel() > 1:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_key_value_heads, 1)
|
||||
loaded_weight = permute(
|
||||
loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
|
||||
)
|
||||
elif (
|
||||
"wk" in modules
|
||||
and modules[-1] == "qscale_weight"
|
||||
and loaded_weight.numel() > 1
|
||||
):
|
||||
loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
|
||||
elif "wq" in modules and modules[-1] == "weight":
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_attention_heads,
|
||||
self.config.hidden_size)
|
||||
elif "wq" in modules and modules[
|
||||
-1] == "qscale_weight" and loaded_weight.numel() > 1:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_attention_heads, 1)
|
||||
loaded_weight = permute(
|
||||
loaded_weight, self.config.num_attention_heads, self.config.hidden_size
|
||||
)
|
||||
elif (
|
||||
"wq" in modules
|
||||
and modules[-1] == "qscale_weight"
|
||||
and loaded_weight.numel() > 1
|
||||
):
|
||||
loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)
|
||||
|
||||
num_modules = len(modules)
|
||||
for i in range(num_modules):
|
||||
item = modules[i]
|
||||
next_item = modules[i + 1] if i < num_modules - 1 else None
|
||||
|
||||
combined_item = (f"{item}.{next_item}"
|
||||
if next_item is not None else None)
|
||||
combined_item = f"{item}.{next_item}" if next_item is not None else None
|
||||
|
||||
if combined_item in mapping:
|
||||
name = name.replace(combined_item, mapping[combined_item])
|
||||
|
||||
Reference in New Issue
Block a user