Refactor Arctic loading to use AutoWeightsLoader (#38955)
Signed-off-by: Lalit Laxminarayan Bangad <lalitbangad@gmail.com>
Co-authored-by: Lalit Laxminarayan Bangad <lalitbangad@meta.com>
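This change replaces the hand-rolled load_weights in ArcticForCausalLM with the shared AutoWeightsLoader and moves the Arctic-specific name mapping down into a new ArcticModel.load_weights. The contract is that a model-level load_weights consumes (name, tensor) pairs and returns the set of parameter names it actually loaded, while AutoWeightsLoader handles prefix skipping and routing to submodules. A minimal sketch of that contract, as a hypothetical simplification rather than vLLM's actual implementation:

    # Hypothetical sketch of the AutoWeightsLoader contract, not vLLM's code.
    from collections.abc import Iterable

    import torch
    import torch.nn as nn


    class SketchAutoWeightsLoader:
        def __init__(self, module: nn.Module, skip_prefixes: list[str] | None = None):
            self.module = module                      # e.g. an ArcticForCausalLM
            self.skip_prefixes = skip_prefixes or []  # e.g. ["lm_head."]

        def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
            loaded: set[str] = set()
            for name, tensor in weights:
                if any(name.startswith(p) for p in self.skip_prefixes):
                    continue  # skipped prefixes are never loaded
                # The real loader recurses into submodules and strips their
                # name prefixes; this sketch assumes a single "model." child
                # that implements its own load_weights, as ArcticModel now does.
                child_loaded = self.module.model.load_weights(
                    [(name.removeprefix("model."), tensor)]
                )
                loaded.update("model." + n for n in child_loaded)
            return loaded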
@@ -16,7 +16,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
@@ -42,6 +41,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
 
 from .interfaces import SupportsPP, SupportsQuant
 from .utils import (
+    AutoWeightsLoader,
     extract_layer_index,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -49,8 +49,6 @@ from .utils import (
     maybe_prefix,
 )
 
-logger = init_logger(__name__)
-
 
 class ArcticMLP(nn.Module):
     def __init__(
@@ -384,6 +382,7 @@ class ArcticModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
+        self.config = config
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
@@ -426,6 +425,116 @@ class ArcticModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        mlp_params_mapping: list[tuple[str, str, int]] = []
+        expert_params_mapping: list[tuple[str, str, int]] = []
+
+        for layer in range(self.config.num_hidden_layers):
+            is_moe_layer = (layer + 1) % self.config.moe_layer_frequency == 0
+            if is_moe_layer and self.config.use_residual:
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.residual_mlp.w13.weight",
+                        f"layers.{layer}.residual_mlp.w1.weight",
+                        0,
+                    )
+                )
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.residual_mlp.w13.weight",
+                        f"layers.{layer}.residual_mlp.w3.weight",
+                        1,
+                    )
+                )
+
+            if is_moe_layer:
+                for expert_id in range(self.config.num_local_experts):
+                    expert_params_mapping.append(
+                        ("ws", f"experts.{expert_id}.w1.weight", expert_id)
+                    )
+                    expert_params_mapping.append(
+                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
+                    )
+                    expert_params_mapping.append(
+                        ("ws", f"experts.{expert_id}.w3.weight", expert_id)
+                    )
+            else:
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
+                        f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
+                        0,
+                    )
+                )
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
+                        f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
+                        1,
+                    )
+                )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for param_name, weight_name, shard_id in mlp_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    for param_name, weight_name, shard_id in expert_params_mapping:
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param, loaded_weight, weight_name, expert_id=shard_id
+                        )
+                        break
+                    else:
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
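Note on the control flow in the hunk above: the mapping tables are chained with Python's for/else, where the else block runs only if the loop finished without hitting break, so each table is tried in turn and the innermost else is the default-loader fallback. A self-contained illustration (names here are made up):

    # Illustration of the for/else dispatch pattern; not Arctic's real tables.
    mappings = [("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k")]

    def classify(name: str) -> str:
        for param_name, weight_name, shard_id in mappings:
            if weight_name not in name:
                continue
            result = f"stacked -> {name.replace(weight_name, param_name)} ({shard_id})"
            break
        else:
            # No mapping matched: fall through to the default loader.
            result = f"default -> {name}"
        return result

    print(classify("layers.0.self_attn.q_proj.weight"))  # stacked -> ...qkv_proj... (q)
    print(classify("layers.0.norm.weight"))              # default -> layers.0.norm.weight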
@@ -478,117 +587,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        mlp_params_mapping: list[tuple[str, str, int]] = []
-        expert_params_mapping: list[tuple[str, str, int]] = []
-        num_layers = self.config.num_hidden_layers
-
-        for layer in range(num_layers):
-            mlp_params_mapping.append(
-                (
-                    f"layers.{layer}.residual_mlp.w13.weight",
-                    f"layers.{layer}.residual_mlp.w1.weight",
-                    0,
-                )
-            )
-            mlp_params_mapping.append(
-                (
-                    f"layers.{layer}.residual_mlp.w13.weight",
-                    f"layers.{layer}.residual_mlp.w3.weight",
-                    1,
-                )
-            )
-            if layer % 2 == 0:
-                # MLP layers
-                mlp_params_mapping.append(
-                    (
-                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
-                        f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
-                        0,
-                    )
-                )
-                mlp_params_mapping.append(
-                    (
-                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
-                        f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
-                        1,
-                    )
-                )
-            else:
-                # MoE layers
-                for expert_id in range(self.config.num_local_experts):
-                    expert_params_mapping.append(
-                        ("ws", f"experts.{expert_id}.w1.weight", expert_id)
-                    )
-                    expert_params_mapping.append(
-                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
-                    )
-                    expert_params_mapping.append(
-                        ("ws", f"experts.{expert_id}.w3.weight", expert_id)
-                    )
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-
-        logger.info(
-            "It will take ~10 minutes loading from the 16-bit weights. "
-            "Alternatively, use the prequantized 8-bit weights of arctic "
-            "and set load-format to `sharded_state` will accelerate loading."
-        )
-        for name, loaded_weight in weights:
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for param_name, weight_name, shard_id in mlp_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    for param_name, weight_name, shard_id in expert_params_mapping:
-                        if weight_name not in name:
-                            continue
-                        name = name.replace(weight_name, param_name)
-                        if is_pp_missing_parameter(name, self):
-                            continue
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        weight_loader(
-                            param, loaded_weight, weight_name, expert_id=shard_id
-                        )
-                        break
-                    else:
-                        if name.endswith(".bias") and name not in params_dict:
-                            continue
-                        if is_pp_missing_parameter(name, self):
-                            continue
-                        param = params_dict[name]
-
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
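Note on skip_prefixes=["lm_head."]: when tie_word_embeddings is set, the output projection shares its tensor with the input embedding, so a separately stored lm_head entry in the checkpoint would be redundant and the loader can drop it. A small standalone illustration (sizes are made up, not Arctic's):

    import torch.nn as nn

    vocab_size, hidden_size = 1000, 64  # illustrative sizes only
    embed_tokens = nn.Embedding(vocab_size, hidden_size)
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head.weight = embed_tokens.weight  # tied: one tensor, two module views
    assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()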