[Model] use AutoWeightsLoader for granite, granitemoe, granitemoeshared, grok1, mixtral (#16325)
Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>
This commit is contained in:
@@ -48,7 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@@ -302,6 +302,8 @@ class Grok1Model(nn.Module):
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
@@ -370,6 +372,105 @@ class Grok1Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
# Map Grok1's unique expert parameter names to standard names
|
||||
# Grok1 uses "num_experts" in its config
|
||||
num_experts = getattr(self.config, "num_experts", 8)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="linear", # Grok1 specific
|
||||
ckpt_down_proj_name="linear_1", # Grok1 specific
|
||||
ckpt_up_proj_name="linear_v", # Grok1 specific
|
||||
num_experts=num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name.endswith("scale"):
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
# Handle Grok1-specific norm.scale naming
|
||||
if "norm.scale" in name:
|
||||
name = name.replace("scale", "weight")
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
fall_back_to_pt_during_load = False
|
||||
@@ -460,106 +561,10 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
skip_prefixes = ["rotary_emb.inv_freq"]
|
||||
# Skip lm_head when tie_word_embeddings is True
|
||||
if self.config.tie_word_embeddings:
|
||||
skip_prefixes.append("lm_head")
|
||||
|
||||
# Map Grok1's unique expert parameter names to standard names
|
||||
# Grok1 uses "num_experts" in its config
|
||||
num_experts = getattr(self.config, "num_experts", 8)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="linear", # Grok1 specific
|
||||
ckpt_down_proj_name="linear_1", # Grok1 specific
|
||||
ckpt_up_proj_name="linear_v", # Grok1 specific
|
||||
num_experts=num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if name.endswith("scale"):
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if ((name.endswith(".bias") or name.endswith("_bias"))
|
||||
and name not in params_dict):
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
# Handle Grok1-specific norm.scale naming
|
||||
if "norm.scale" in name:
|
||||
name = name.replace("scale", "weight")
|
||||
|
||||
# Skip lm_head when tie_word_embeddings is True
|
||||
if "lm_head" in name and self.config.tie_word_embeddings:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user