Enable ModelOpt Llama4 fp8 checkpoint deployment (#20419)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
@@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
qkv_weight = torch.cat(weight, dim=0)
|
||||
yield key, qkv_weight
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
|
||||
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
|
||||
format."""
|
||||
if name.startswith("model."):
|
||||
# Handle expert scale parameters with flat naming
|
||||
if "feed_forward.experts." in name and ("_input_scale" in name or
|
||||
"_weight_scale" in name):
|
||||
renamed = name.replace("model.", "language_model.model.", 1)
|
||||
# Map checkpoint naming to vLLM's expected naming
|
||||
if "down_proj_input_scale" in renamed:
|
||||
return renamed.replace("down_proj_input_scale",
|
||||
"w2_input_scale")
|
||||
elif "down_proj_weight_scale" in renamed:
|
||||
return renamed.replace("down_proj_weight_scale",
|
||||
"w2_weight_scale")
|
||||
elif "gate_up_proj_input_scale" in renamed:
|
||||
return renamed.replace("gate_up_proj_input_scale",
|
||||
"w13_input_scale")
|
||||
elif "gate_up_proj_weight_scale" in renamed:
|
||||
return renamed.replace("gate_up_proj_weight_scale",
|
||||
"w13_weight_scale")
|
||||
return renamed
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
updated_params: set[str] = set()
|
||||
# Handle attention scale parameters
|
||||
elif "self_attn." in name and (".k_scale" in name
|
||||
or ".v_scale" in name):
|
||||
renamed = name.replace("model.", "language_model.model.", 1)
|
||||
if ".k_proj.k_scale" in renamed:
|
||||
return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
|
||||
elif ".v_proj.v_scale" in renamed:
|
||||
return renamed.replace(".v_proj.v_scale", ".attn.v_scale")
|
||||
return renamed
|
||||
|
||||
# language_model is an Llama4ForCausalLM instance. We load it's
|
||||
# using llama4's load_weights routine.
|
||||
language_model_weights, other_weights = self.separate_weights(
|
||||
weights, prefix="language_model.")
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_language_model_params = loader.load_weights(
|
||||
language_model_weights)
|
||||
assert loaded_language_model_params is not None
|
||||
updated_params.update(loaded_language_model_params)
|
||||
# Standard model.* to language_model.model.* renaming
|
||||
return name.replace("model.", "language_model.model.", 1)
|
||||
|
||||
elif name.startswith("lm_head.weight"):
|
||||
return name.replace("lm_head.weight",
|
||||
"language_model.lm_head.weight")
|
||||
|
||||
return name
|
||||
|
||||
def _separate_and_rename_weights(
|
||||
self, weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]:
|
||||
"""Rename weights and separate them into language_model and other
|
||||
weights."""
|
||||
language_model_weights = []
|
||||
other_weights = []
|
||||
|
||||
for name, weight in weights:
|
||||
renamed = self._rename_weight_for_modelopt_checkpoint(name)
|
||||
|
||||
if renamed.startswith("language_model."):
|
||||
language_model_weights.append((renamed, weight))
|
||||
else:
|
||||
other_weights.append((renamed, weight))
|
||||
|
||||
return language_model_weights, other_weights
|
||||
|
||||
def _handle_expert_scale_broadcasting(
|
||||
self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
|
||||
) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
|
||||
"""Handle expert scale parameters that need broadcasting.
|
||||
|
||||
ModelOpt checkpoints use a single value tensor scalar for BMM style
|
||||
experts, vLLM expects the scale to be broadcasted across all experts.
|
||||
"""
|
||||
regular_weights = []
|
||||
expert_scale_weights = []
|
||||
updated_params = set()
|
||||
|
||||
for name, weight in weights:
|
||||
# Check if this is an expert scale parameter that needs broadcasting
|
||||
if ("feed_forward.experts." in name and "scale" in name
|
||||
and ".shared_expert" not in name):
|
||||
if name in params_dict:
|
||||
param = params_dict[name]
|
||||
if (hasattr(param, 'data') and param.data.numel() > 1
|
||||
and weight.numel() == 1):
|
||||
# Broadcast single value to all experts
|
||||
param.data.fill_(weight.item())
|
||||
updated_params.add(name)
|
||||
continue
|
||||
|
||||
expert_scale_weights.append((name, weight))
|
||||
else:
|
||||
regular_weights.append((name, weight))
|
||||
|
||||
return regular_weights, expert_scale_weights, updated_params
|
||||
|
||||
def _load_other_weights(self, other_weights: Iterable[tuple[str,
|
||||
torch.Tensor]],
|
||||
params_dict: dict,
|
||||
stacked_params_mapping: list) -> set[str]:
|
||||
"""Load non-language-model weights with stacking support."""
|
||||
updated_params = set()
|
||||
|
||||
if self.use_data_parallel:
|
||||
other_weights = self._consolidate_qkv_weights(other_weights)
|
||||
|
||||
for name, loaded_weight in other_weights:
|
||||
# Try stacked parameter mapping first
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name or self.use_data_parallel:
|
||||
continue
|
||||
@@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Use regular weight loading
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
|
||||
weight_loader(param, loaded_weight)
|
||||
updated_params.add(name)
|
||||
|
||||
return updated_params
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
# Shared expert gate_up_proj stacking
|
||||
(".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
|
||||
(".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
|
||||
# Feed forward gate_up_proj stacking (for non-MoE layers if any)
|
||||
(".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
|
||||
(".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
updated_params: set[str] = set()
|
||||
|
||||
# Separate and rename weights
|
||||
language_model_weights, other_weights = (
|
||||
self._separate_and_rename_weights(weights))
|
||||
|
||||
# Handle expert scale parameters
|
||||
regular_weights, expert_scale_weights, updated_params_from_experts = (
|
||||
self._handle_expert_scale_broadcasting(language_model_weights,
|
||||
params_dict))
|
||||
updated_params.update(updated_params_from_experts)
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded_language_model_params = loader.load_weights(regular_weights)
|
||||
assert loaded_language_model_params is not None
|
||||
updated_params.update(loaded_language_model_params)
|
||||
|
||||
if expert_scale_weights:
|
||||
loaded_expert_scale_params = loader.load_weights(
|
||||
expert_scale_weights)
|
||||
if loaded_expert_scale_params:
|
||||
updated_params.update(loaded_expert_scale_params)
|
||||
|
||||
updated_params.update(
|
||||
self._load_other_weights(other_weights, params_dict,
|
||||
stacked_params_mapping))
|
||||
|
||||
return updated_params
|
||||
|
||||
Reference in New Issue
Block a user