Enable ModelOpt Llama4 fp8 checkpoint deployment (#20419)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Author: Zhiyu
Date: 2025-07-11 23:07:16 -07:00
Committed by: GitHub
Parent: 5de8d9f111
Commit: 4afe687a82
5 changed files with 501 additions and 35 deletions


@@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    @classmethod
@@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
            qkv_weight = torch.cat(weight, dim=0)
            yield key, qkv_weight
    def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
        """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
        format."""
        if name.startswith("model."):
            # Handle expert scale parameters with flat naming
            if "feed_forward.experts." in name and ("_input_scale" in name or
                                                    "_weight_scale" in name):
                renamed = name.replace("model.", "language_model.model.", 1)
                # Map checkpoint naming to vLLM's expected naming
                if "down_proj_input_scale" in renamed:
                    return renamed.replace("down_proj_input_scale",
                                           "w2_input_scale")
                elif "down_proj_weight_scale" in renamed:
                    return renamed.replace("down_proj_weight_scale",
                                           "w2_weight_scale")
                elif "gate_up_proj_input_scale" in renamed:
                    return renamed.replace("gate_up_proj_input_scale",
                                           "w13_input_scale")
                elif "gate_up_proj_weight_scale" in renamed:
                    return renamed.replace("gate_up_proj_weight_scale",
                                           "w13_weight_scale")
                return renamed

            # Handle attention scale parameters
            elif "self_attn." in name and (".k_scale" in name
                                           or ".v_scale" in name):
                renamed = name.replace("model.", "language_model.model.", 1)
                if ".k_proj.k_scale" in renamed:
                    return renamed.replace(".k_proj.k_scale", ".attn.k_scale")
                elif ".v_proj.v_scale" in renamed:
                    return renamed.replace(".v_proj.v_scale", ".attn.v_scale")
                return renamed

            # Standard model.* to language_model.model.* renaming
            return name.replace("model.", "language_model.model.", 1)

        elif name.startswith("lm_head.weight"):
            return name.replace("lm_head.weight",
                                "language_model.lm_head.weight")

        return name
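    # Illustration (not part of the diff): a few checkpoint keys and the vLLM
    # names this helper produces, assuming `m` is a loaded
    # Llama4ForConditionalGeneration instance; the layer index is made up.
    #
    #   >>> m._rename_weight_for_modelopt_checkpoint(
    #   ...     "model.layers.0.feed_forward.experts.gate_up_proj_weight_scale")
    #   'language_model.model.layers.0.feed_forward.experts.w13_weight_scale'
    #   >>> m._rename_weight_for_modelopt_checkpoint(
    #   ...     "model.layers.0.self_attn.k_proj.k_scale")
    #   'language_model.model.layers.0.self_attn.attn.k_scale'
    #   >>> m._rename_weight_for_modelopt_checkpoint("lm_head.weight")
    #   'language_model.lm_head.weight'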
    def _separate_and_rename_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str,
                                                          torch.Tensor]]]:
        """Rename weights and separate them into language_model and other
        weights."""
        language_model_weights = []
        other_weights = []
        for name, weight in weights:
            renamed = self._rename_weight_for_modelopt_checkpoint(name)
            if renamed.startswith("language_model."):
                language_model_weights.append((renamed, weight))
            else:
                other_weights.append((renamed, weight))
        return language_model_weights, other_weights
    def _handle_expert_scale_broadcasting(
        self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
    ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]],
               set[str]]:
        """Handle expert scale parameters that need broadcasting.

        ModelOpt checkpoints store a single scalar tensor for BMM-style
        experts, while vLLM expects the scale to be broadcast across all
        experts.
        """
        regular_weights = []
        expert_scale_weights = []
        updated_params = set()

        for name, weight in weights:
            # Check if this is an expert scale parameter that needs
            # broadcasting
            if ("feed_forward.experts." in name and "scale" in name
                    and ".shared_expert" not in name):
                if name in params_dict:
                    param = params_dict[name]
                    if (hasattr(param, 'data') and param.data.numel() > 1
                            and weight.numel() == 1):
                        # Broadcast single value to all experts
                        param.data.fill_(weight.item())
                        updated_params.add(name)
                        continue

                expert_scale_weights.append((name, weight))
            else:
                regular_weights.append((name, weight))

        return regular_weights, expert_scale_weights, updated_params
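    # Illustration (not part of the diff): the broadcast above in isolation,
    # with a hypothetical 16-expert scale parameter. ModelOpt stores a single
    # scalar; fill_ replicates it across every expert slot.
    #
    #   >>> import torch
    #   >>> param_data = torch.zeros(16)       # one scale slot per expert
    #   >>> ckpt_scale = torch.tensor(0.025)   # scalar scale from checkpoint
    #   >>> _ = param_data.fill_(ckpt_scale.item())
    #   >>> bool((param_data == param_data[0]).all())
    #   True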
    def _load_other_weights(self, other_weights: Iterable[tuple[str,
                                                                torch.Tensor]],
                            params_dict: dict,
                            stacked_params_mapping: list) -> set[str]:
        """Load non-language-model weights with stacking support."""
        updated_params = set()
        if self.use_data_parallel:
            other_weights = self._consolidate_qkv_weights(other_weights)

        for name, loaded_weight in other_weights:
            # Try stacked parameter mapping first
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name or self.use_data_parallel:
                    continue
@@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Use regular weight loading
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            updated_params.add(name)
        return updated_params
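    # Illustration (not part of the diff): how one stacked-mapping entry
    # routes a checkpoint weight, assuming the usual vLLM stacked-parameter
    # pattern (the module path here is hypothetical).
    #
    #   name = "vision_model.layers.0.self_attn.q_proj.weight"
    #   weight_name = ".self_attn.q_proj" matches, so the name is rewritten to
    #   "vision_model.layers.0.self_attn.qkv_proj.weight" and the shard id "q"
    #   tells weight_loader to write only the q slice of the fused qkv_proj;
    #   the for/else falls through to default_weight_loader only when no
    #   mapping entry matched.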
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
            # Shared expert gate_up_proj stacking
            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
            # Feed forward gate_up_proj stacking (for non-MoE layers, if any)
            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params: set[str] = set()

        # Separate and rename weights
        language_model_weights, other_weights = (
            self._separate_and_rename_weights(weights))

        # Handle expert scale parameters
        regular_weights, expert_scale_weights, updated_params_from_experts = (
            self._handle_expert_scale_broadcasting(language_model_weights,
                                                   params_dict))
        updated_params.update(updated_params_from_experts)

        loader = AutoWeightsLoader(self)
        loaded_language_model_params = loader.load_weights(regular_weights)
        assert loaded_language_model_params is not None
        updated_params.update(loaded_language_model_params)

        if expert_scale_weights:
            loaded_expert_scale_params = loader.load_weights(
                expert_scale_weights)
            if loaded_expert_scale_params:
                updated_params.update(loaded_expert_scale_params)

        updated_params.update(
            self._load_other_weights(other_weights, params_dict,
                                     stacked_params_mapping))

        return updated_params
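With this in place, a ModelOpt Llama4 fp8 checkpoint loads through the
standard vLLM entry points. A minimal sketch, assuming a local checkpoint
directory (the path is a placeholder; vLLM's "modelopt" quantization method
is normally detected from the checkpoint's quantization config, so the
explicit argument may be optional):

    from vllm import LLM, SamplingParams

    llm = LLM(model="/path/to/llama4-fp8-modelopt",
              quantization="modelopt")
    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(out[0].outputs[0].text)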