From 93726b2a1c766a27f10ad7c3f709ffbacb94bcf9 Mon Sep 17 00:00:00 2001
From: lalit10
Date: Fri, 3 Apr 2026 22:01:09 -0700
Subject: [PATCH] Refactor Arctic loading to use AutoWeightsLoader (#38955)

Signed-off-by: Lalit Laxminarayan Bangad
Co-authored-by: Lalit Laxminarayan Bangad
---
 vllm/model_executor/models/arctic.py | 232 +++++++++++++--------------
 1 file changed, 116 insertions(+), 116 deletions(-)

diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index 031b6534f..0c9267994 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -16,7 +16,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
@@ -42,6 +41,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
 
 from .interfaces import SupportsPP, SupportsQuant
 from .utils import (
+    AutoWeightsLoader,
     extract_layer_index,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -49,8 +49,6 @@ from .utils import (
     maybe_prefix,
 )
 
-logger = init_logger(__name__)
-
 
 class ArcticMLP(nn.Module):
     def __init__(
@@ -384,6 +382,7 @@ class ArcticModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
+        self.config = config
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
@@ -426,6 +425,116 @@ class ArcticModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        mlp_params_mapping: list[tuple[str, str, int]] = []
+        expert_params_mapping: list[tuple[str, str, int]] = []
+
+        for layer in range(self.config.num_hidden_layers):
+            is_moe_layer = (layer + 1) % self.config.moe_layer_frequency == 0
+            if is_moe_layer and self.config.use_residual:
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.residual_mlp.w13.weight",
+                        f"layers.{layer}.residual_mlp.w1.weight",
+                        0,
+                    )
+                )
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.residual_mlp.w13.weight",
+                        f"layers.{layer}.residual_mlp.w3.weight",
+                        1,
+                    )
+                )
+
+            if is_moe_layer:
+                for expert_id in range(self.config.num_local_experts):
+                    expert_params_mapping.append(
+                        ("ws", f"experts.{expert_id}.w1.weight", expert_id)
+                    )
+                    expert_params_mapping.append(
+                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
+                    )
+                    expert_params_mapping.append(
+                        ("ws", f"experts.{expert_id}.w3.weight", expert_id)
+                    )
+            else:
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
+                        f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
+                        0,
+                    )
+                )
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
+                        f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
+                        1,
+                    )
+                )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for param_name, weight_name, shard_id in mlp_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    for param_name, weight_name, shard_id in expert_params_mapping:
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param, loaded_weight, weight_name, expert_id=shard_id
+                        )
+                        break
+                    else:
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
@@ -478,117 +587,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        mlp_params_mapping: list[tuple[str, str, int]] = []
-        expert_params_mapping: list[tuple[str, str, int]] = []
-        num_layers = self.config.num_hidden_layers
-
-        for layer in range(num_layers):
-            mlp_params_mapping.append(
-                (
-                    f"layers.{layer}.residual_mlp.w13.weight",
-                    f"layers.{layer}.residual_mlp.w1.weight",
-                    0,
-                )
-            )
-            mlp_params_mapping.append(
-                (
-                    f"layers.{layer}.residual_mlp.w13.weight",
-                    f"layers.{layer}.residual_mlp.w3.weight",
-                    1,
-                )
-            )
-            if layer % 2 == 0:
-                # MLP layers
-                mlp_params_mapping.append(
-                    (
-                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
-                        f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
-                        0,
-                    )
-                )
-                mlp_params_mapping.append(
-                    (
-                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
-                        f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
-                        1,
-                    )
-                )
-            else:
-                # MoE layers
-                for expert_id in range(self.config.num_local_experts):
-                    expert_params_mapping.append(
-                        ("ws", f"experts.{expert_id}.w1.weight", expert_id)
-                    )
-                    expert_params_mapping.append(
-                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
-                    )
-                    expert_params_mapping.append(
-                        ("ws", f"experts.{expert_id}.w3.weight", expert_id)
-                    )
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-
-        logger.info(
-            "It will take ~10 minutes loading from the 16-bit weights. "
-            "Alternatively, use the prequantized 8-bit weights of arctic "
-            "and set load-format to `sharded_state` will accelerate loading."
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        for name, loaded_weight in weights:
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for param_name, weight_name, shard_id in mlp_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    for param_name, weight_name, shard_id in expert_params_mapping:
-                        if weight_name not in name:
-                            continue
-                        name = name.replace(weight_name, param_name)
-                        if is_pp_missing_parameter(name, self):
-                            continue
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        weight_loader(
-                            param, loaded_weight, weight_name, expert_id=shard_id
-                        )
-                        break
-                    else:
-                        if name.endswith(".bias") and name not in params_dict:
-                            continue
-                        if is_pp_missing_parameter(name, self):
-                            continue
-                        param = params_dict[name]
-
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        return loader.load_weights(weights)
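
Note for reviewers: the stacked_params_mapping triples kept in
ArcticModel.load_weights are vLLM's standard device for fusing per-projection
checkpoint shards into one packed parameter (QKVParallelLinear-style). As a
quick illustration outside the diff, here is a minimal, self-contained Python
sketch of what one (param_name, shard_name, shard_id) entry accomplishes; the
toy_qkv_weight_loader helper and the tensor shapes are invented for this note
and are not vLLM APIs:

import torch

HIDDEN = 8  # toy hidden size; q, k and v are each HIDDEN x HIDDEN here

# The fused parameter a packed linear layer exposes as qkv_proj:
# rows [0, H) hold Q, [H, 2H) hold K, [2H, 3H) hold V.
fused_qkv = torch.empty(3 * HIDDEN, HIDDEN)
SHARD_ROW = {"q": 0, "k": 1, "v": 2}

def toy_qkv_weight_loader(param, loaded_weight, shard_id):
    """Copy one checkpoint shard into its slice of the fused QKV weight."""
    row = SHARD_ROW[shard_id] * HIDDEN
    param[row : row + HIDDEN].copy_(loaded_weight)

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint = {
    "layers.0.self_attn.q_proj.weight": torch.randn(HIDDEN, HIDDEN),
    "layers.0.self_attn.k_proj.weight": torch.randn(HIDDEN, HIDDEN),
    "layers.0.self_attn.v_proj.weight": torch.randn(HIDDEN, HIDDEN),
}

for name, tensor in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Same rename step as in the patch: q_proj/k_proj/v_proj -> qkv_proj.
        name = name.replace(weight_name, param_name)
        toy_qkv_weight_loader(fused_qkv, tensor, shard_id)
        break

assert torch.equal(fused_qkv[:HIDDEN], checkpoint["layers.0.self_attn.q_proj.weight"])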
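On the ArcticForCausalLM side, load_weights is now a thin delegation.
AutoWeightsLoader (imported from .utils in the first hunk) walks the
checkpoint names and routes each weight to the submodule that owns it, which
is why the per-parameter mapping logic could move into ArcticModel wholesale.
The following is a rough sketch of that delegation idea only, assuming a
single wrapped child module; ToyAutoWeightsLoader is a simplified stand-in
written for this note, not vLLM's actual implementation:

from collections.abc import Iterable

import torch
import torch.nn as nn

class ToyAutoWeightsLoader:
    """Simplified stand-in for the real AutoWeightsLoader (illustration only)."""

    def __init__(self, module: nn.Module, skip_prefixes: list[str] | None = None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        params = dict(self.module.named_parameters())
        for name, tensor in weights:
            # Weights the wrapper deliberately ignores, e.g. a tied lm_head.
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue
            # Route "model.<rest>" to a child that defines its own load_weights.
            child_name, _, rest = name.partition(".")
            child = getattr(self.module, child_name, None)
            if child is not None and hasattr(child, "load_weights"):
                for sub_name in child.load_weights([(rest, tensor)]):
                    loaded.add(f"{child_name}.{sub_name}")
            elif name in params:
                # Fall back to a plain copy for parameters owned directly.
                params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded

The skip_prefixes=(["lm_head."] ...) guard follows the same convention as
other vLLM models: when tie_word_embeddings is set, the checkpoint's lm_head
weights duplicate embed_tokens, so they are skipped rather than reported as
unexpected.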