[Bugfix] Fix Kimi-K2.5 NVFP4 checkpoints weight loading (#33876)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-02-05 18:29:54 +08:00
committed by GitHub
parent 59a5cb387a
commit a2522839d8
2 changed files with 15 additions and 5 deletions

View File

@@ -1485,7 +1485,7 @@ class DeepseekV2ForCausalLM(
param, "weight_loader", default_weight_loader param, "weight_loader", default_weight_loader
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer: if name is not None and not is_fusion_moe_shared_experts_layer:
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params

View File

@@ -24,7 +24,11 @@ from transformers.processing_utils import ProcessorMixin
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP from vllm.model_executor.models.interfaces import (
SupportsMultiModal,
SupportsPP,
SupportsQuant,
)
from vllm.model_executor.models.kimi_k25_vit import ( from vllm.model_executor.models.kimi_k25_vit import (
KimiK25MultiModalProjector, KimiK25MultiModalProjector,
MoonViT3dPretrainedModel, MoonViT3dPretrainedModel,
@@ -302,7 +306,9 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
info=KimiK25ProcessingInfo, info=KimiK25ProcessingInfo,
dummy_inputs=KimiK25DummyInputsBuilder, dummy_inputs=KimiK25DummyInputsBuilder,
) )
class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class KimiK25ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
):
"""Kimi-K2.5 model for conditional generation. """Kimi-K2.5 model for conditional generation.
Supports both image and video-chunk modalities. Supports both image and video-chunk modalities.
@@ -312,8 +318,12 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
supports_encoder_tp_data = True supports_encoder_tp_data = True
weights_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# For legacy NVFP4 checkpoint compatibility:
# see https://github.com/vllm-project/vllm/pull/33346#issuecomment-3851475033
"language_model.layers.": "language_model.model.layers.",
# mm projector
"mm_projector.proj.0": "mm_projector.linear_1", "mm_projector.proj.0": "mm_projector.linear_1",
"mm_projector.proj.2": "mm_projector.linear_2", "mm_projector.proj.2": "mm_projector.linear_2",
} }
@@ -465,4 +475,4 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.weights_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)