[Bugfix][Model] Fix FP8 k_scale/v_scale not loaded for Qwen3-MoE (#35656)
Signed-off-by: raghavan <oneraghavan@gmail.com>
This commit is contained in:
@@ -535,10 +535,6 @@ class Qwen3MoeModel(nn.Module):
|
||||
ignore_suffixes = (
|
||||
".bias",
|
||||
"_bias",
|
||||
".k_scale",
|
||||
"_k_scale",
|
||||
".v_scale",
|
||||
"_v_scale",
|
||||
".weight_scale",
|
||||
"_weight_scale",
|
||||
".input_scale",
|
||||
@@ -562,6 +558,10 @@ class Qwen3MoeModel(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name or "zero_point" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
@@ -654,20 +654,8 @@ class Qwen3MoeModel(nn.Module):
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if name.endswith("kv_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
".kv_scale", ".attn.kv_scale"
|
||||
)
|
||||
if remapped_kv_scale_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
|
||||
name,
|
||||
remapped_kv_scale_name,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
|
||||
@@ -172,10 +172,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
ignore_suffixes = (
|
||||
".bias",
|
||||
"_bias",
|
||||
".k_scale",
|
||||
"_k_scale",
|
||||
".v_scale",
|
||||
"_v_scale",
|
||||
".weight_scale",
|
||||
"_weight_scale",
|
||||
".input_scale",
|
||||
@@ -191,6 +187,11 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
]
|
||||
num_experts = self.config.num_experts
|
||||
for name, loaded_weight in weights:
|
||||
if "scale" in name or "zero_point" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if "experts.gate_up_proj" in name or "experts.down_proj" in name:
|
||||
is_fused_expert = True
|
||||
@@ -305,20 +306,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
|
||||
# Skip layers on other devices.
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
if name.endswith("kv_scale"):
|
||||
remapped_kv_scale_name = name.replace(
|
||||
".kv_scale", ".attn.kv_scale"
|
||||
)
|
||||
if remapped_kv_scale_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501
|
||||
name,
|
||||
remapped_kv_scale_name,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
name = remapped_kv_scale_name
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
|
||||
Reference in New Issue
Block a user