Refactor Arctic loading to use AutoWeightsLoader (#38955)
Signed-off-by: Lalit Laxminarayan Bangad <lalitbangad@gmail.com>
Co-authored-by: Lalit Laxminarayan Bangad <lalitbangad@meta.com>
@@ -16,7 +16,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
@@ -42,6 +41,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig

 from .interfaces import SupportsPP, SupportsQuant
 from .utils import (
+    AutoWeightsLoader,
     extract_layer_index,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -49,8 +49,6 @@ from .utils import (
     maybe_prefix,
 )

-logger = init_logger(__name__)
-

 class ArcticMLP(nn.Module):
     def __init__(
@@ -384,6 +382,7 @@ class ArcticModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config

+        self.config = config
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
@@ -426,6 +425,116 @@ class ArcticModel(nn.Module):
         hidden_states = self.norm(hidden_states)
         return hidden_states

+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        mlp_params_mapping: list[tuple[str, str, int]] = []
+        expert_params_mapping: list[tuple[str, str, int]] = []
+
+        for layer in range(self.config.num_hidden_layers):
+            is_moe_layer = (layer + 1) % self.config.moe_layer_frequency == 0
+            if is_moe_layer and self.config.use_residual:
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.residual_mlp.w13.weight",
+                        f"layers.{layer}.residual_mlp.w1.weight",
+                        0,
+                    )
+                )
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.residual_mlp.w13.weight",
+                        f"layers.{layer}.residual_mlp.w3.weight",
+                        1,
+                    )
+                )
+
+            if is_moe_layer:
+                for expert_id in range(self.config.num_local_experts):
+                    expert_params_mapping.append(
+                        ("ws", f"experts.{expert_id}.w1.weight", expert_id)
+                    )
+                    expert_params_mapping.append(
+                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
+                    )
+                    expert_params_mapping.append(
+                        ("ws", f"experts.{expert_id}.w3.weight", expert_id)
+                    )
+            else:
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
+                        f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
+                        0,
+                    )
+                )
+                mlp_params_mapping.append(
+                    (
+                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
+                        f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
+                        1,
+                    )
+                )
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for param_name, weight_name, shard_id in mlp_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    for param_name, weight_name, shard_id in expert_params_mapping:
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param, loaded_weight, weight_name, expert_id=shard_id
+                        )
+                        break
+                    else:
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+                        if is_pp_missing_parameter(name, self):
+                            continue
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+

 class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
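Note: the new ArcticModel.load_weights keeps the Arctic-specific name mappings from the old code: q/k/v projections stack into qkv_proj, per-layer w1/w3 weights stack into the fused w13 parameters, and per-expert w1/w2/w3 weights feed the stacked ws/w2s expert parameters. The standalone sketch below (not part of the diff, and using hypothetical toy config values rather than Arctic's real ones) replays just the mapping-construction loop above to show which checkpoint names it pairs with which fused parameters.

# Standalone sketch; toy values, not Arctic's real configuration.
from types import SimpleNamespace

config = SimpleNamespace(
    num_hidden_layers=2,     # hypothetical
    moe_layer_frequency=2,   # every second layer is an MoE layer
    use_residual=True,
    num_local_experts=2,     # hypothetical
)

mlp_params_mapping: list[tuple[str, str, int]] = []
expert_params_mapping: list[tuple[str, str, int]] = []

for layer in range(config.num_hidden_layers):
    is_moe_layer = (layer + 1) % config.moe_layer_frequency == 0
    if is_moe_layer and config.use_residual:
        # residual MLP: w1 fills shard 0 and w3 fills shard 1 of the fused w13 weight
        for shard_id, src in enumerate(("w1", "w3")):
            mlp_params_mapping.append(
                (
                    f"layers.{layer}.residual_mlp.w13.weight",
                    f"layers.{layer}.residual_mlp.{src}.weight",
                    shard_id,
                )
            )
    if is_moe_layer:
        # each expert's w1/w3 load into the stacked "ws" parameter, w2 into "w2s"
        for expert_id in range(config.num_local_experts):
            expert_params_mapping.append(("ws", f"experts.{expert_id}.w1.weight", expert_id))
            expert_params_mapping.append(("w2s", f"experts.{expert_id}.w2.weight", expert_id))
            expert_params_mapping.append(("ws", f"experts.{expert_id}.w3.weight", expert_id))
    else:
        # dense layers: block_sparse_moe.mlp w1/w3 load into the fused w13 weight
        for shard_id, src in enumerate(("w1", "w3")):
            mlp_params_mapping.append(
                (
                    f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                    f"layers.{layer}.block_sparse_moe.mlp.{src}.weight",
                    shard_id,
                )
            )

print(mlp_params_mapping)     # w13 shard targets for dense/residual MLPs
print(expert_params_mapping)  # stacked expert weights ("ws"/"w2s")

With this toy config, layer 0 is dense and contributes only the two block_sparse_moe.mlp entries, while layer 1 is MoE and contributes the residual_mlp pair plus six expert entries (two experts, three weights each).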
@@ -478,117 +587,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
         return logits

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
-
-        mlp_params_mapping: list[tuple[str, str, int]] = []
-        expert_params_mapping: list[tuple[str, str, int]] = []
-        num_layers = self.config.num_hidden_layers
-
-        for layer in range(num_layers):
-            mlp_params_mapping.append(
-                (
-                    f"layers.{layer}.residual_mlp.w13.weight",
-                    f"layers.{layer}.residual_mlp.w1.weight",
-                    0,
-                )
-            )
-            mlp_params_mapping.append(
-                (
-                    f"layers.{layer}.residual_mlp.w13.weight",
-                    f"layers.{layer}.residual_mlp.w3.weight",
-                    1,
-                )
-            )
-            if layer % 2 == 0:
-                # MLP layers
-                mlp_params_mapping.append(
-                    (
-                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
-                        f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
-                        0,
-                    )
-                )
-                mlp_params_mapping.append(
-                    (
-                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
-                        f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
-                        1,
-                    )
-                )
-            else:
-                # MoE layers
-                for expert_id in range(self.config.num_local_experts):
-                    expert_params_mapping.append(
-                        ("ws", f"experts.{expert_id}.w1.weight", expert_id)
-                    )
-                    expert_params_mapping.append(
-                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
-                    )
-                    expert_params_mapping.append(
-                        ("ws", f"experts.{expert_id}.w3.weight", expert_id)
-                    )
-
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-
-        logger.info(
-            "It will take ~10 minutes loading from the 16-bit weights. "
-            "Alternatively, use the prequantized 8-bit weights of arctic "
-            "and set load-format to `sharded_state` will accelerate loading."
-        )
-        for name, loaded_weight in weights:
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for param_name, weight_name, shard_id in mlp_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
-                else:
-                    for param_name, weight_name, shard_id in expert_params_mapping:
-                        if weight_name not in name:
-                            continue
-                        name = name.replace(weight_name, param_name)
-                        if is_pp_missing_parameter(name, self):
-                            continue
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        weight_loader(
-                            param, loaded_weight, weight_name, expert_id=shard_id
-                        )
-                        break
-                    else:
-                        if name.endswith(".bias") and name not in params_dict:
-                            continue
-                        if is_pp_missing_parameter(name, self):
-                            continue
-                        param = params_dict[name]
-
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
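Note: with this hunk, ArcticForCausalLM no longer walks the checkpoint itself. It builds an AutoWeightsLoader that, as used here, skips "lm_head." weights when word embeddings are tied and routes the remaining tensors to child modules by name, so the "model."-prefixed weights end up in the ArcticModel.load_weights added above. The sketch below is a rough, simplified stand-in for that routing, using toy classes and a plain prefix rule rather than vLLM's actual AutoWeightsLoader implementation.

# Toy illustration only: simplified stand-ins, not vLLM's real classes.
from collections.abc import Iterable
import torch


class ToyArcticModel:
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Stand-in for ArcticModel.load_weights above; just records the names.
        return {name for name, _ in weights}


class ToyAutoWeightsLoader:
    """Simplified analogue of AutoWeightsLoader: drop skip_prefixes, then hand
    the "model."-prefixed tensors to the inner model's own load_weights."""

    def __init__(self, model, skip_prefixes=None):
        self.model = model
        self.skip_prefixes = skip_prefixes or []

    def load_weights(self, weights):
        kept = [
            (name, w)
            for name, w in weights
            if not any(name.startswith(p) for p in self.skip_prefixes)
        ]
        routed = [
            (name.removeprefix("model."), w)
            for name, w in kept
            if name.startswith("model.")
        ]
        return {"model." + name for name in self.model.load_weights(routed)}


loader = ToyAutoWeightsLoader(ToyArcticModel(), skip_prefixes=["lm_head."])
loaded = loader.load_weights(
    [
        ("model.embed_tokens.weight", torch.empty(8, 4)),  # hypothetical shapes
        ("lm_head.weight", torch.empty(8, 4)),             # skipped when tied
    ]
)
print(loaded)  # {"model.embed_tokens.weight"}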