Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,8 +23,10 @@ import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import set_weight_attrs
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
|
||||
@@ -32,7 +34,6 @@ from .utils import AutoWeightsLoader, WeightsMapper
|
||||
|
||||
|
||||
class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
@@ -45,14 +46,12 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
|
||||
f"model.{self.tp_rank}.pt",
|
||||
]
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
# fairseq2's serialization adds a wrapper to usual .pt state_dict's:
|
||||
# { "model_key": my_model_name, "my_model_name": state_dict }
|
||||
# which we first need to unpack
|
||||
weights_wrapped = dict(weights)
|
||||
weights = weights_wrapped[
|
||||
weights_wrapped["model_key"]].items() # type: ignore
|
||||
weights = weights_wrapped[weights_wrapped["model_key"]].items() # type: ignore
|
||||
|
||||
# remap keys
|
||||
fs2_to_vllm_mapper = WeightsMapper(
|
||||
@@ -77,12 +76,14 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(
|
||||
(self.reshape_fairseq2_weights(name, loaded_weight, params)
|
||||
for name, loaded_weight in weights))
|
||||
(
|
||||
self.reshape_fairseq2_weights(name, loaded_weight, params)
|
||||
for name, loaded_weight in weights
|
||||
)
|
||||
)
|
||||
|
||||
def flag_sharded_weights(self, params: dict[str, Parameter]):
|
||||
"""Sets the `is_sharded_weight` flag to True for all sharded weights"""
|
||||
@@ -113,35 +114,34 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
|
||||
attn_in //= self.tp_size
|
||||
n_heads //= self.tp_size
|
||||
attn_out = self.config.hidden_size
|
||||
return (w.view(n_heads, attn_in // n_heads // 2, 2,
|
||||
attn_out).transpose(1,
|
||||
2).reshape(attn_in, attn_out))
|
||||
return (
|
||||
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
|
||||
.transpose(1, 2)
|
||||
.reshape(attn_in, attn_out)
|
||||
)
|
||||
|
||||
modules = name.split(".")
|
||||
|
||||
# rotary embeds should be sliced
|
||||
if "k_proj" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_key_value_heads)
|
||||
loaded_weight = permute(loaded_weight, self.config.num_key_value_heads)
|
||||
|
||||
elif "q_proj" in modules:
|
||||
loaded_weight = permute(loaded_weight,
|
||||
self.config.num_attention_heads)
|
||||
loaded_weight = permute(loaded_weight, self.config.num_attention_heads)
|
||||
|
||||
# We make the loaded weights compatible with both
|
||||
# full checkpoints and tp sharded checkpoints.
|
||||
# Embeddings are repeated to fit the vocab size.
|
||||
# Other weights are flagged for the weight_loader calls.
|
||||
# Other weights are flagged for the weight_loader calls.
|
||||
if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
|
||||
# Embeddings are sharded on dim 0
|
||||
dim = 0
|
||||
# In fairseq2, vocab size has to be divisible by tp_size
|
||||
# so we don't worry about padding
|
||||
if self.tp_size > 1 and loaded_weight.shape[
|
||||
dim] < self.config.vocab_size:
|
||||
assert loaded_weight.shape[
|
||||
dim] * self.tp_size == self.config.vocab_size, \
|
||||
"vocab_size should be divisible by tp_size."
|
||||
if self.tp_size > 1 and loaded_weight.shape[dim] < self.config.vocab_size:
|
||||
assert (
|
||||
loaded_weight.shape[dim] * self.tp_size == self.config.vocab_size
|
||||
), "vocab_size should be divisible by tp_size."
|
||||
repeats = [1] * len(loaded_weight.size())
|
||||
repeats[dim] = self.tp_size
|
||||
# repeat to match vocab size and to be easily 'narrow'able
|
||||
|
||||
Reference in New Issue
Block a user