Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -23,6 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BailingMoE model compatible with HuggingFace weights."""
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union
@@ -35,31 +36,42 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
class BailingAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -79,8 +91,7 @@ class BailingAttention(nn.Module):
assert self.total_num_heads >= self.total_kv_heads
self.num_heads = self.total_num_heads // tp_size
self.head_dim = config.head_dim or (self.hidden_size //
self.total_num_heads)
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
self.q_size_per_rank = self.head_dim * self.num_heads
self.num_kv_heads = self.total_kv_heads // tp_size
self.kv_size_per_rank = self.num_kv_heads * self.head_dim
@@ -99,12 +110,16 @@ class BailingAttention(nn.Module):
)
if self.use_qk_norm:
self.query_layernorm = (RMSNorm(
self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6))
self.key_layernorm = (RMSNorm(
self.head_dim, eps=config.rms_norm_eps) if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6))
self.query_layernorm = (
RMSNorm(self.head_dim, eps=config.rms_norm_eps)
if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6)
)
self.key_layernorm = (
RMSNorm(self.head_dim, eps=config.rms_norm_eps)
if self.use_rmsnorm
else nn.LayerNorm(self.head_dim, eps=1e-6)
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
@@ -115,8 +130,7 @@ class BailingAttention(nn.Module):
prefix=f"{prefix}.dense",
)
self.partial_rotary_factor = getattr(config, "partial_rotary_factor",
1.0)
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
self.rotary_dim = getattr(config, "rotary_dim", self.head_dim)
@@ -144,12 +158,10 @@ class BailingAttention(nn.Module):
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([
self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank
],
dim=-1)
q, k, v = qkv.split(
[self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], dim=-1
)
if self.use_qk_norm:
q = q.view(-1, self.num_heads, self.head_dim)
@@ -168,7 +180,6 @@ class BailingAttention(nn.Module):
class BailingMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
@@ -203,7 +214,6 @@ class BailingMLP(nn.Module):
class BailingMoE(nn.Module):
def __init__(
self,
intermediate_size: int,
@@ -225,10 +235,8 @@ class BailingMoE(nn.Module):
self.score_function = getattr(config, "score_function", None)
self.n_group = getattr(config, "n_group", None)
self.topk_group = getattr(config, "topk_group", None)
self.use_grouped_topk = (self.n_group is not None
and self.topk_group is not None)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor",
1.0)
self.use_grouped_topk = self.n_group is not None and self.topk_group is not None
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
router_dtype = getattr(config, "router_dtype", None)
if router_dtype is None:
@@ -247,21 +255,23 @@ class BailingMoE(nn.Module):
if getattr(config, "moe_router_enable_expert_bias", False):
self.gate.expert_bias = nn.Parameter(
torch.empty((config.num_experts, ), dtype=torch.float32))
torch.empty((config.num_experts,), dtype=torch.float32)
)
else:
self.gate.expert_bias = None
self.correction_bias = (self.gate.expert_bias.data
if self.gate.expert_bias is not None else None)
self.correction_bias = (
self.gate.expert_bias.data if self.gate.expert_bias is not None else None
)
if self.score_function is not None:
assert (
self.score_function == "softmax"
and self.correction_bias is None
self.score_function == "softmax" and self.correction_bias is None
) or (
self.score_function == "sigmoid"
and self.correction_bias is not None
), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" # noqa: E501
self.score_function == "sigmoid" and self.correction_bias is not None
), (
"score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"
) # noqa: E501
else:
# default value for scoring_func
self.score_function = "softmax"
@@ -293,7 +303,8 @@ class BailingMoE(nn.Module):
config=config,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_experts")
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None
@@ -306,8 +317,9 @@ class BailingMoE(nn.Module):
router_logits = self.gate(hidden_states.to(self.router_dtype))
router_logits = router_logits.to(hidden_states.dtype)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
final_hidden_states *= self.routed_scaling_factor
@@ -315,13 +327,11 @@ class BailingMoE(nn.Module):
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class BailingMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -330,30 +340,26 @@ class BailingMoeBlock(nn.Module):
prefix: str = "",
):
super().__init__()
layer_idx = int(prefix.split('.')[-1])
layer_idx = int(prefix.split(".")[-1])
self.config = config
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
self.attention = BailingAttention(config,
cache_config,
quant_config,
prefix=f"{prefix}.attention")
self.attention = BailingAttention(
config, cache_config, quant_config, prefix=f"{prefix}.attention"
)
self.post_attention_layernorm = RMSNorm(hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
# Choose MLP class based on the number of experts and layer index
if layer_idx < config.first_k_dense_replace:
mlp_class = BailingMLP
else:
mlp_class = BailingMoE
self.mlp = mlp_class(intermediate_size,
config,
quant_config,
True,
prefix=f"{prefix}.mlp")
self.mlp = mlp_class(
intermediate_size, config, quant_config, True, prefix=f"{prefix}.mlp"
)
def forward(
self,
@@ -365,23 +371,20 @@ class BailingMoeBlock(nn.Module):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.attention(
hidden_states=hidden_states,
position_ids=position_ids,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class BailingMoeModel(nn.Module):
def __init__(
self,
*,
@@ -396,11 +399,11 @@ class BailingMoeModel(nn.Module):
self.config = config
self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size
self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
False)
self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
if get_pp_group().is_first_rank or (self.tie_word_embeddings
and get_pp_group().is_last_rank):
if get_pp_group().is_first_rank or (
self.tie_word_embeddings and get_pp_group().is_last_rank
):
self.word_embeddings = VocabParallelEmbedding(
self.vocab_size,
self.embed_dim,
@@ -420,11 +423,12 @@ class BailingMoeModel(nn.Module):
quant_config=quant_config,
prefix=prefix,
),
prefix=f"{prefix}.layers")
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
@@ -460,10 +464,9 @@ class BailingMoeModel(nn.Module):
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
else:
if residual is None:
hidden_states = self.norm(hidden_states)
@@ -479,8 +482,7 @@ class BailingMoeModel(nn.Module):
num_experts=self.config.num_experts,
)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
@@ -491,14 +493,14 @@ class BailingMoeModel(nn.Module):
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if (hasattr(self.config, "norm_head") and self.config.norm_head
and "lm_head.weight" in name):
loaded_weight = F.normalize(loaded_weight,
dim=0,
p=2,
eps=1e-7)
if (
hasattr(self.config, "norm_head")
and self.config.norm_head
and "lm_head.weight" in name
):
loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
@@ -548,15 +550,15 @@ class BailingMoeModel(nn.Module):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"gate_up_proj": [
@@ -582,10 +584,10 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.lora_config = lora_config
self.quant_config = quant_config
self.max_position_embeddings = config.max_position_embeddings
self.model = BailingMoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.tie_word_embeddings = getattr(config, "tie_word_embeddings",
False)
self.model = BailingMoeModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
if get_pp_group().is_last_rank:
if self.tie_word_embeddings:
@@ -602,7 +604,8 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.model.make_empty_intermediate_tensors
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@@ -614,8 +617,9 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
model_output = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return model_output
def compute_logits(
@@ -625,8 +629,7 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.tie_word_embeddings else None),