Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -19,7 +19,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
"""PyTorch Starcoder2 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
@@ -33,31 +34,43 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from .utils import (
AutoWeightsLoader,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
class Starcoder2Attention(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@@ -107,13 +120,15 @@ class Starcoder2Attention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
@@ -129,11 +144,12 @@ class Starcoder2Attention(nn.Module):
class Starcoder2MLP(nn.Module):
def __init__(self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
config: Starcoder2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
@@ -159,25 +175,28 @@ class Starcoder2MLP(nn.Module):
class Starcoder2DecoderLayer(nn.Module):
def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.mlp = Starcoder2MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.self_attn = Starcoder2Attention(
config,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = Starcoder2MLP(
config, quant_config=quant_config, prefix=f"{prefix}.mlp"
)
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.norm_epsilon
)
def forward(
self,
@@ -204,7 +223,6 @@ class Starcoder2DecoderLayer(nn.Module):
@support_torch_compile
class Starcoder2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -219,7 +237,8 @@ class Starcoder2Model(nn.Module):
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
prefix=f"{prefix}.embed_tokens",
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Starcoder2DecoderLayer(
@@ -228,9 +247,9 @@ class Starcoder2Model(nn.Module):
prefix=f"{prefix}.layers",
)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.hidden_size
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -257,8 +276,7 @@ class Starcoder2Model(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -269,7 +287,7 @@ class Starcoder2Model(nn.Module):
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
@@ -286,22 +304,21 @@ class Starcoder2Model(nn.Module):
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.model = Starcoder2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = Starcoder2Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:
@@ -316,10 +333,12 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -331,8 +350,9 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
def compute_logits(
@@ -342,13 +362,13 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
skip_prefixes=(["lm_head.weight"]
if self.config.tie_word_embeddings else None),
skip_prefixes=(
["lm_head.weight"] if self.config.tie_word_embeddings else None
),
)
return loader.load_weights(weights)