[Bugfix][Model] Fix FP8 k_scale/v_scale not loaded for Qwen3-MoE (#35656)

Signed-off-by: raghavan <oneraghavan@gmail.com>
This commit is contained in:
Raghavan
2026-03-04 18:45:38 +05:30
committed by GitHub
parent bb6888b8b1
commit c8c3935b70
3 changed files with 129 additions and 36 deletions

View File

@@ -535,10 +535,6 @@ class Qwen3MoeModel(nn.Module):
ignore_suffixes = (
".bias",
"_bias",
".k_scale",
"_k_scale",
".v_scale",
"_v_scale",
".weight_scale",
"_weight_scale",
".input_scale",
@@ -562,6 +558,10 @@ class Qwen3MoeModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
@@ -654,20 +654,8 @@ class Qwen3MoeModel(nn.Module):
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale"
)
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader

View File

@@ -172,10 +172,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
ignore_suffixes = (
".bias",
"_bias",
".k_scale",
"_k_scale",
".v_scale",
"_v_scale",
".weight_scale",
"_weight_scale",
".input_scale",
@@ -191,6 +187,11 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
]
num_experts = self.config.num_experts
for name, loaded_weight in weights:
if "scale" in name or "zero_point" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
is_fused_expert = True
@@ -305,20 +306,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale"
)
if remapped_kv_scale_name not in params_dict:
logger.warning_once(
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
name,
remapped_kv_scale_name,
)
continue
else:
name = remapped_kv_scale_name
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader