Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
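The change shown below for the Llama 4 model code is mechanical: yapf + isort output is rewritten into ruff's isort-compatible import sorting and Black-compatible formatting. As a minimal sketch (an assumption about tooling, not the repository's actual pre-commit or CI configuration, which may invoke ruff differently), a single file could be converted to the same style like this, given that ruff is installed:

# Sketch: convert one file to the new style with ruff, replacing isort + yapf.
# Assumes `ruff` is installed (e.g. `pip install ruff`); hypothetical helper, not vLLM's own tooling.
import subprocess
import sys


def reformat_with_ruff(path: str) -> None:
    # Import sorting via ruff's isort-compatible "I" rules (replaces isort).
    subprocess.run(
        [sys.executable, "-m", "ruff", "check", "--select", "I", "--fix", path],
        check=True,
    )
    # Black-compatible code formatting (replaces yapf).
    subprocess.run([sys.executable, "-m", "ruff", "format", path], check=True)


if __name__ == "__main__":
    reformat_with_ruff(sys.argv[1])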


@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Any, Optional
@@ -28,27 +29,36 @@ from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
is_pp_missing_parameter)
from .utils import (
AutoWeightsLoader,
extract_layer_index,
fast_topk,
is_pp_missing_parameter,
)
class Llama4MoE(nn.Module):
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
@@ -73,11 +83,13 @@ class Llama4MoE(nn.Module):
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.router")
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.router",
)
self.shared_expert = LlamaMLP(
hidden_size=config.hidden_size,
@@ -123,26 +135,28 @@ class Llama4MoE(nn.Module):
experts_out = experts_out[:num_tokens]
elif self.tp_size > 1:
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out)
experts_out
)
return experts_out
class Llama4Attention(nn.Module):
def __init__(self,
config: Llama4TextConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "") -> None:
def __init__(
self,
config: Llama4TextConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
bias_o_proj: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
@@ -167,20 +181,23 @@ class Llama4Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn_temperature_tuning = self.nope and \
config.attn_temperature_tuning
self.attn_temperature_tuning = self.nope and config.attn_temperature_tuning
self.floor_scale = getattr(config, "floor_scale", 8192.0)
self.attn_scale = getattr(config, "attn_scale", 0.1)
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.n_rep = self.num_heads // self.num_kv_heads
self.qk_norm = RMSNorm(
hidden_size=self.head_dim,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
) if self.use_qk_norm else None
self.qk_norm = (
RMSNorm(
hidden_size=self.head_dim,
eps=config.rms_norm_eps,
has_weight=False,
dtype=torch.float32,
)
if self.use_qk_norm
else None
)
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
@@ -203,18 +220,21 @@ class Llama4Attention(nn.Module):
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
) if not self.nope else None
self.rotary_emb = (
get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
rope_scaling=rope_scaling if rope_scaling != "default" else None,
is_neox_style=is_neox_style,
)
if not self.nope
else None
)
use_chunked_local_attn = not self.nope and config.attention_chunk_size
attn_cls = (ChunkedLocalAttention
if use_chunked_local_attn else Attention)
attn_cls = ChunkedLocalAttention if use_chunked_local_attn else Attention
self.attn = attn_cls(
self.num_heads,
self.head_dim,
@@ -223,9 +243,12 @@ class Llama4Attention(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
**({
"attention_chunk_size": config.attention_chunk_size
} if use_chunked_local_attn else {}))
**(
{"attention_chunk_size": config.attention_chunk_size}
if use_chunked_local_attn
else {}
),
)
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
floor = torch.floor((positions + 1.0) / self.floor_scale)
@@ -270,11 +293,12 @@ class Llama4Attention(nn.Module):
class Llama4DecoderLayer(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[Llama4TextConfig] = None) -> None:
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[Llama4TextConfig] = None,
) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
@@ -302,8 +326,10 @@ class Llama4DecoderLayer(nn.Module):
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
is_moe_layer = config.interleave_moe_layer_step > 0 and (
self.layer_idx + 1) % config.interleave_moe_layer_step == 0
is_moe_layer = (
config.interleave_moe_layer_step > 0
and (self.layer_idx + 1) % config.interleave_moe_layer_step == 0
)
if is_moe_layer:
self.feed_forward = Llama4MoE(
vllm_config=vllm_config,
@@ -318,10 +344,10 @@ class Llama4DecoderLayer(nn.Module):
bias=False,
prefix=f"{prefix}.feed_forward",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
@@ -334,30 +360,26 @@ class Llama4DecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
@support_torch_compile
class Llama4Model(LlamaModel):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
):
self.num_experts = vllm_config.model_config.hf_config.num_local_experts
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
def load_moe_expert_weights(
self,
@@ -408,9 +430,7 @@ class Llama4Model(LlamaModel):
# Iterate over all the expert parameters and load the weights if we find
# a match in weight name.
for (param_name, weight_name, expert_id,
shard_id) in expert_params_mapping:
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
# Get a view of the loaded_weight to avoid modifying the original
# one across iterations.
new_loaded_weight = loaded_weight
@@ -419,7 +439,7 @@ class Llama4Model(LlamaModel):
# the expert index from the expected weight name.
if fused:
# The string between e_str and proj_str is the expert index.
e_str, _, proj_str, _ = weight_name.split('.')
e_str, _, proj_str, _ = weight_name.split(".")
weight_name = f"{e_str}.{proj_str}"
param_name = f"{param_name}weight"
@@ -436,8 +456,9 @@ class Llama4Model(LlamaModel):
continue
# Skip if the current weight is for the bias.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
if (
name.endswith(".bias") or name.endswith("_bias")
) and name not in params_dict:
continue
param = params_dict[full_param_name]
@@ -456,13 +477,14 @@ class Llama4Model(LlamaModel):
# starting expert index for the current EP rank and extract the
# corresponding expert weights.
layer_idx = extract_layer_index(name)
expert_map = self.layers[
layer_idx].feed_forward.experts.expert_map
expert_map = self.layers[layer_idx].feed_forward.experts.expert_map
if expert_map is not None:
local_expert_indices = (expert_map != -1) \
.nonzero() \
.flatten() \
.to(new_loaded_weight.device)
local_expert_indices = (
(expert_map != -1)
.nonzero()
.flatten()
.to(new_loaded_weight.device)
)
new_loaded_weight = new_loaded_weight[local_expert_indices]
expert_id = local_expert_indices[0].item()
else:
@@ -471,19 +493,20 @@ class Llama4Model(LlamaModel):
# Load the weight into the module parameter with corresponding
# shard id and expert id.
weight_loader(param,
new_loaded_weight,
full_param_name,
shard_id=shard_id,
expert_id=expert_id)
weight_loader(
param,
new_loaded_weight,
full_param_name,
shard_id=shard_id,
expert_id=expert_id,
)
loaded_params.add(full_param_name)
expert_param_loaded = True
return expert_param_loaded
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# Name mapping from the parameter name to the shard name and
# corresponding shard id.
stacked_params_mapping = [
@@ -503,14 +526,16 @@ class Llama4Model(LlamaModel):
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.num_experts)
num_experts=self.num_experts,
)
# Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor.
expert_params_mapping_fused = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="gate_up_proj",
num_experts=1)
num_experts=1,
)
# All the module parameters.
params_dict = dict(self.named_parameters())
# The module parameters that have been loaded.
@@ -518,7 +543,6 @@ class Llama4Model(LlamaModel):
# Iterate over all the weights and load them into module parameters.
for name, loaded_weight in weights:
# If the name contains "experts.gate_up_proj" or "experts.down_proj"
# without the expert indices, it means the expert weights are fused
# into a single weight tensor across all experts.
@@ -529,13 +553,14 @@ class Llama4Model(LlamaModel):
# If kv cache quantization scales exist and the weight name
# corresponds to one of the kv cache quantization scales, load
# them.
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
@@ -552,8 +577,9 @@ class Llama4Model(LlamaModel):
# For ModelOpt checkpoints, we need to rename the self_attn
# weight/weight_scale names except for kv cache scales.
if not (name.endswith(
(".k_scale", ".v_scale")) and "self_attn" in name):
if not (
name.endswith((".k_scale", ".v_scale")) and "self_attn" in name
):
name = name.replace(weight_name, param_name)
# Skip if the current weight corresponds to a parameter that
@@ -572,8 +598,7 @@ class Llama4Model(LlamaModel):
# Load the weight into the module parameter with corresponding
# shard id and exit the for loop and the else block.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if weight_loader == default_weight_loader:
weight_loader(param, loaded_weight)
@@ -587,12 +612,14 @@ class Llama4Model(LlamaModel):
else:
# First, try to load MoE weights using load_moe_expert_weights.
# If successful, move on to next loaded weight.
if self.load_moe_expert_weights(name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params):
if self.load_moe_expert_weights(
name,
loaded_weight,
params_dict,
loaded_params,
expert_params_mapping,
fused=fused_experts_params,
):
continue
# Skip if the current weight corresponds to a parameter that
@@ -604,37 +631,40 @@ class Llama4Model(LlamaModel):
# per-expert patterns, i.e. one weight scale tensor for all
# experts.
scale_names = [
"w13_input_scale", "w13_weight_scale", "w2_input_scale",
"w2_weight_scale"
"w13_input_scale",
"w13_weight_scale",
"w2_input_scale",
"w2_weight_scale",
]
if ("experts." in name and any(scale_name in name
for scale_name in scale_names)):
if "experts." in name and any(
scale_name in name for scale_name in scale_names
):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
# If weight loader supports special moe loading, use it to
# avoid expensive runtime reflection
if getattr(weight_loader, 'supports_moe_loading', False):
if getattr(weight_loader, "supports_moe_loading", False):
# Map the weight name to the corresponding shard id.
shard_id = "w2" if "w2_" in name else "w1"
# Transpose if weight scales are FP8 block scales with
# three dimensions:
# [num_experts, hidden_in, hidden_out].
if name.endswith("weight_scale") \
and loaded_weight.dtype == torch.float8_e4m3fn \
and loaded_weight.ndim == 3:
if (
name.endswith("weight_scale")
and loaded_weight.dtype == torch.float8_e4m3fn
and loaded_weight.ndim == 3
):
loaded_weight = loaded_weight.transpose(-1, -2)
# Load the weight into the module parameter with
# corresponding shard id and expert id.
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=0)
weight_loader(
param, loaded_weight, name, shard_id=shard_id, expert_id=0
)
else:
# Regular weight loader (handles both
@@ -646,8 +676,7 @@ class Llama4Model(LlamaModel):
# Handle normal (non-stacked, non-MoE) weights.
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
@@ -656,7 +685,6 @@ class Llama4Model(LlamaModel):
class Llama4ForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -667,30 +695,29 @@ class Llama4ForCausalLM(LlamaForCausalLM):
gen_config = vllm_config.model_config.try_get_generation_config()
gen_config.update(vllm_config.model_config.override_generation_config)
# enable temperature tuning by default when max_model_len > 32K
default_attn_temperature_tuning = \
vllm_config.model_config.max_model_len > 32768
vllm_config.model_config.hf_config.attn_temperature_tuning \
= gen_config.get(
"attn_temperature_tuning", default_attn_temperature_tuning)
default_attn_temperature_tuning = vllm_config.model_config.max_model_len > 32768
vllm_config.model_config.hf_config.attn_temperature_tuning = gen_config.get(
"attn_temperature_tuning", default_attn_temperature_tuning
)
super().__init__(vllm_config=vllm_config,
prefix=prefix,
layer_type=Llama4DecoderLayer)
super().__init__(
vllm_config=vllm_config, prefix=prefix, layer_type=Llama4DecoderLayer
)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer):
return Llama4Model(vllm_config=vllm_config,
prefix=prefix,
layer_type=layer_type)
def _init_model(
self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[Llama4DecoderLayer] = Llama4DecoderLayer,
):
return Llama4Model(
vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
@@ -703,10 +730,8 @@ class Llama4ForCausalLM(LlamaForCausalLM):
name: str,
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
# Helper function to permute the weight's channels
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
# Calculate the expected shape of the weight.
# Do not rely on w's shape, as it may be in another layout.
attn_in = self.config.head_dim * n_heads
@@ -719,28 +744,39 @@ class Llama4ForCausalLM(LlamaForCausalLM):
# If the weight is a weight scale, we need to divide attn_out by
# block size, which is currently 16.
elif w.dtype == torch.float8_e4m3fn and is_weight_scale \
and w.shape[1] * 16 == attn_out:
elif (
w.dtype == torch.float8_e4m3fn
and is_weight_scale
and w.shape[1] * 16 == attn_out
):
attn_out = attn_out // 16
return w.view(n_heads, attn_in // n_heads // 2, 2,
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2)
.reshape(attn_in, attn_out)
)
modules = name.split(".")
# Permute Q/K weights and weight block scales for rotary embedding
is_weight = modules[-1] == "weight"
is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
loaded_weight.dtype == torch.float8_e4m3fn)
is_nvfp4_weight_scale = (
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
)
if is_weight or is_nvfp4_weight_scale:
if ("wk" in modules or "k_proj" in modules):
loaded_weight = permute(loaded_weight,
self.config.num_key_value_heads,
is_nvfp4_weight_scale)
elif ("wq" in modules or "q_proj" in modules):
loaded_weight = permute(loaded_weight,
self.config.num_attention_heads,
is_nvfp4_weight_scale)
if "wk" in modules or "k_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_key_value_heads,
is_nvfp4_weight_scale,
)
elif "wq" in modules or "q_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_attention_heads,
is_nvfp4_weight_scale,
)
return name, loaded_weight