[Bugfix][Model] Fix FP8 k_scale/v_scale not loaded for Qwen3-MoE (#35656)

Signed-off-by: raghavan <oneraghavan@gmail.com>
Raghavan
2026-03-04 18:45:38 +05:30
committed by GitHub
parent bb6888b8b1
commit c8c3935b70
3 changed files with 129 additions and 36 deletions
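
FP8 checkpoints exported by llm-compressor store KV-cache scales under the fused projection name (e.g. qkv_proj.k_scale), but the Qwen3-MoE loader listed .k_scale/.v_scale in its ignore list and only remapped the deprecated .kv_scale suffix, so the scales were never loaded. The fix routes every scale/zero_point name through maybe_remap_kv_scale_name before parameter matching. The remapping contract pinned down by the new tests is roughly the following (a simplified sketch consistent with the tests below, not the actual implementation in vllm/model_executor/model_loader/weight_utils.py):

# Sketch only: illustrates the name-remapping contract the new tests assert.
def remap_kv_scale_name_sketch(name: str, params_dict: dict) -> str | None:
    if not name.endswith((".k_scale", ".v_scale", ".kv_scale")):
        return name  # non-scale names (weights, weight_scale, input_scale) pass through
    for old, new in (
        (".qkv_proj.k_scale", ".attn.k_scale"),  # llm-compressor / Qwen3-MoE
        (".qkv_proj.v_scale", ".attn.v_scale"),
        (".k_proj.k_scale", ".attn.k_scale"),    # ModelOpt (incl. NVFP4)
        (".v_proj.v_scale", ".attn.v_scale"),
        (".kv_scale", ".attn.k_scale"),          # deprecated fused KV scale
        (".k_scale", ".attn.k_scale"),           # bare default format
        (".v_scale", ".attn.v_scale"),
    ):
        if name.endswith(old):
            remapped = name[: -len(old)] + new
            # If the model has no matching parameter, signal "skip" with None.
            return remapped if remapped in params_dict else None
    return name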

tests/model_executor/test_weight_utils.py

@@ -11,6 +11,7 @@ from huggingface_hub.utils import LocalEntryNotFoundError
from vllm.model_executor.model_loader.weight_utils import (
    download_weights_from_hf,
    enable_hf_transfer,
    maybe_remap_kv_scale_name,
)
@@ -61,6 +62,121 @@ def test_download_weights_from_hf():
)
class TestMaybeRemapKvScaleName:
    """Tests for maybe_remap_kv_scale_name covering all checkpoint formats."""

    PARAMS_DICT = {
        "model.layers.0.self_attn.attn.k_scale": None,
        "model.layers.0.self_attn.attn.v_scale": None,
        "model.layers.0.self_attn.attn.q_scale": None,
        "model.layers.0.self_attn.qkv_proj.weight": None,
    }

    def test_qkv_proj_k_scale(self):
        """Qwen3-MoE / llm-compressor format: qkv_proj.k_scale -> attn.k_scale
        Regression test for https://github.com/vllm-project/vllm/issues/25047"""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.qkv_proj.k_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.k_scale"

    def test_qkv_proj_v_scale(self):
        """Qwen3-MoE / llm-compressor format: qkv_proj.v_scale -> attn.v_scale
        Regression test for https://github.com/vllm-project/vllm/issues/25047"""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.qkv_proj.v_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.v_scale"

    def test_modelopt_k_proj_k_scale(self):
        """ModelOpt format: k_proj.k_scale -> attn.k_scale"""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.k_proj.k_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.k_scale"

    def test_modelopt_v_proj_v_scale(self):
        """ModelOpt format: v_proj.v_scale -> attn.v_scale"""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.v_proj.v_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.v_scale"

    def test_deprecated_kv_scale(self):
        """Old format: kv_scale -> attn.k_scale (deprecated)"""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.kv_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.k_scale"

    def test_default_bare_k_scale(self):
        """Default format: .k_scale -> .attn.k_scale"""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.k_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.k_scale"

    def test_non_scale_name_unchanged(self):
        """Non-scale names should be returned unchanged."""
        name = "model.layers.0.self_attn.qkv_proj.weight"
        result = maybe_remap_kv_scale_name(name, self.PARAMS_DICT)
        assert result == name

    def test_nvfp4_modelopt_k_proj_k_scale(self):
        """ModelOpt NVFP4 format (e.g. nvidia/Qwen3-30B-A3B-NVFP4):
        k_proj.k_scale -> attn.k_scale.
        Validates that NVFP4 checkpoints are not broken by this change."""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.k_proj.k_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.k_scale"

    def test_nvfp4_modelopt_v_proj_v_scale(self):
        """ModelOpt NVFP4 format (e.g. nvidia/Qwen3-30B-A3B-NVFP4):
        v_proj.v_scale -> attn.v_scale.
        Validates that NVFP4 checkpoints are not broken by this change."""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.v_proj.v_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.v_scale"

    def test_qwen3_vl_moe_qkv_proj_k_scale(self):
        """Qwen3-VL-MoE uses the same fused qkv_proj naming as Qwen3-MoE.
        Regression test for qwen3_vl_moe.py fix (same bug as #25047)."""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.qkv_proj.k_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.k_scale"

    def test_qwen3_vl_moe_qkv_proj_v_scale(self):
        """Qwen3-VL-MoE uses the same fused qkv_proj naming as Qwen3-MoE.
        Regression test for qwen3_vl_moe.py fix (same bug as #25047)."""
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.qkv_proj.v_scale", self.PARAMS_DICT
        )
        assert result == "model.layers.0.self_attn.attn.v_scale"

    def test_nvfp4_weight_scale_not_remapped(self):
        """NVFP4 weight_scale should not be touched by remap (not a kv scale)."""
        name = "model.layers.0.self_attn.k_proj.weight_scale"
        result = maybe_remap_kv_scale_name(name, self.PARAMS_DICT)
        assert result == name

    def test_nvfp4_input_scale_not_remapped(self):
        """NVFP4 input_scale should not be touched by remap (not a kv scale)."""
        name = "model.layers.0.self_attn.k_proj.input_scale"
        result = maybe_remap_kv_scale_name(name, self.PARAMS_DICT)
        assert result == name

    def test_missing_target_returns_none(self):
        """If remapped name not in params_dict, return None."""
        empty_params: dict[str, None] = {}
        result = maybe_remap_kv_scale_name(
            "model.layers.0.self_attn.qkv_proj.k_scale", empty_params
        )
        assert result is None


if __name__ == "__main__":
    test_hf_transfer_auto_activation()
    test_download_weights_from_hf()

vllm/model_executor/models/qwen3_moe.py

@@ -535,10 +535,6 @@ class Qwen3MoeModel(nn.Module):
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
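
Removing the .k_scale/.v_scale suffixes from this ignore list is half of the fix: these names previously matched the ignore list and were dropped before any remapping could run. They now fall through to the maybe_remap_kv_scale_name branch added below.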
@@ -562,6 +558,10 @@ class Qwen3MoeModel(nn.Module):
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name or "zero_point" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
@@ -654,20 +654,8 @@ class Qwen3MoeModel(nn.Module):
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale"
                    )
                    if remapped_kv_scale_name not in params_dict:
                        logger.warning_once(
                            "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                            name,
                            remapped_kv_scale_name,
                        )
                        continue
                    else:
                        name = remapped_kv_scale_name
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
vllm/model_executor/models/qwen3_vl_moe.py

@@ -172,10 +172,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
        ignore_suffixes = (
            ".bias",
            "_bias",
            ".k_scale",
            "_k_scale",
            ".v_scale",
            "_v_scale",
            ".weight_scale",
            "_weight_scale",
            ".input_scale",
@@ -191,6 +187,11 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
        ]
        num_experts = self.config.num_experts
        for name, loaded_weight in weights:
            if "scale" in name or "zero_point" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                    is_fused_expert = True
@@ -305,20 +306,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale"
                    )
                    if remapped_kv_scale_name not in params_dict:
                        logger.warning_once(
                            "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                            name,
                            remapped_kv_scale_name,
                        )
                        continue
                    else:
                        name = remapped_kv_scale_name
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader