Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored 2025-10-05 15:06:22 +01:00, committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions


@@ -23,8 +23,10 @@ import torch
 from torch.nn import Parameter
 from vllm.config import VllmConfig
-from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.linear import set_weight_attrs
 from vllm.model_executor.models.llama import LlamaForCausalLM
@@ -32,7 +34,6 @@ from .utils import AutoWeightsLoader, WeightsMapper
 class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -45,14 +46,12 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
f"model.{self.tp_rank}.pt",
]
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# fairseq2's serialization adds a wrapper to usual .pt state_dict's:
# { "model_key": my_model_name, "my_model_name": state_dict }
# which we first need to unpack
weights_wrapped = dict(weights)
weights = weights_wrapped[
weights_wrapped["model_key"]].items() # type: ignore
weights = weights_wrapped[weights_wrapped["model_key"]].items() # type: ignore
# remap keys
fs2_to_vllm_mapper = WeightsMapper(
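The comments in this hunk describe fairseq2's checkpoint wrapper: the usual state_dict is nested under a model-name key, and "model_key" records which key that is. A small self-contained sketch of the unwrapping step (toy names, not a real fairseq2 checkpoint):

import torch

# Toy stand-in for a fairseq2 .pt checkpoint: the real state_dict is nested
# under a model-name key, and "model_key" points at that key.
wrapped = {
    "model_key": "llama",
    "llama": {"decoder.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4)},
}

# Same unwrapping as in load_weights above: look up the model name, then take
# (name, tensor) pairs from the nested state_dict.
weights_wrapped = dict(wrapped.items())
weights = weights_wrapped[weights_wrapped["model_key"]].items()
for name, tensor in weights:
    print(name, tuple(tensor.shape))
# decoder.layers.0.self_attn.q_proj.weight (4, 4)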
@@ -77,12 +76,14 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
+            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
         return loader.load_weights(
-            (self.reshape_fairseq2_weights(name, loaded_weight, params)
-             for name, loaded_weight in weights))
+            (
+                self.reshape_fairseq2_weights(name, loaded_weight, params)
+                for name, loaded_weight in weights
+            )
+        )

     def flag_sharded_weights(self, params: dict[str, Parameter]):
         """Sets the `is_sharded_weight` flag to True for all sharded weights"""
@@ -113,35 +114,34 @@ class Fairseq2LlamaForCausalLM(LlamaForCausalLM):
                 attn_in //= self.tp_size
                 n_heads //= self.tp_size
             attn_out = self.config.hidden_size
-            return (w.view(n_heads, attn_in // n_heads // 2, 2,
-                           attn_out).transpose(1,
-                                               2).reshape(attn_in, attn_out))
+            return (
+                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
+                .transpose(1, 2)
+                .reshape(attn_in, attn_out)
+            )

         modules = name.split(".")

         # rotary embeds should be sliced
         if "k_proj" in modules:
-            loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads)
         elif "q_proj" in modules:
-            loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+            loaded_weight = permute(loaded_weight, self.config.num_attention_heads)

         # We make the loaded weights compatible with both
         # full checkpoints and tp sharded checkpoints.
         # Embeddings are repeated to fit the vocab size.
         # Other weights are flagged for the weight_loader calls.
         if any(emb in modules for emb in ["embed_tokens", "lm_head"]):
             # Embeddings are sharded on dim 0
             dim = 0

             # In fairseq2, vocab size has to be divisible by tp_size
             # so we don't worry about padding
-            if self.tp_size > 1 and loaded_weight.shape[
-                    dim] < self.config.vocab_size:
-                assert loaded_weight.shape[
-                    dim] * self.tp_size == self.config.vocab_size, \
-                    "vocab_size should be divisible by tp_size."
+            if self.tp_size > 1 and loaded_weight.shape[dim] < self.config.vocab_size:
+                assert (
+                    loaded_weight.shape[dim] * self.tp_size == self.config.vocab_size
+                ), "vocab_size should be divisible by tp_size."

                 repeats = [1] * len(loaded_weight.size())
                 repeats[dim] = self.tp_size

                 # repeat to match vocab size and to be easily 'narrow'able
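The view/transpose/reshape chain inside permute appears to be the standard Llama rotary-embedding permutation: within each attention head, interleaved rotary pairs are regrouped into a half-split layout. A minimal sketch of the same reshape on a toy tensor (toy sizes, no tensor-parallel sharding assumed):

import torch

def permute(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    # Same chain as in reshape_fairseq2_weights: group rows into
    # (head, rotary pair, component), swap pair/component, flatten back.
    attn_in, attn_out = w.shape
    return (
        w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
        .transpose(1, 2)
        .reshape(attn_in, attn_out)
    )

# 2 heads with head_dim 4, projecting from a hidden size of 3.
w = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)
print(permute(w, n_heads=2))
# Within each head, rows are reordered from (0, 1, 2, 3) to (0, 2, 1, 3).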