[Models] Refactor Kimi-K2.5 weight loading (#33346)

Author: Isotr0py
Date: 2026-01-30 13:31:20 +08:00 (committed by GitHub)
Commit: 8bfc8d5600 (parent: ec51831a22)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

2 changed files with 40 additions and 176 deletions
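In short: the model now relies on vLLM's generic `AutoWeightsLoader` plus a class-level `WeightsMapper`, instead of the hand-rolled `load_weights` loop removed below. A minimal sketch of the resulting loading path (class body abridged; only the names visible in this diff are taken from the actual code):

from torch import nn
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper

class KimiK25ForConditionalGeneration(nn.Module):  # other bases omitted in this sketch
    # Checkpoint prefixes are renamed up front so every weight name matches the
    # module hierarchy; AutoWeightsLoader can then route tensors by attribute path.
    weights_mapper = WeightsMapper(
        orig_to_new_prefix={
            "mm_projector.proj.0": "mm_projector.linear_1",
            "mm_projector.proj.2": "mm_projector.linear_2",
        }
    )

    def load_weights(self, weights):
        # AutoWeightsLoader walks the submodules (vision_tower, mm_projector,
        # language_model) and defers to a submodule's own load_weights where one
        # is defined, e.g. the nested DeepseekV2ForCausalLM built in __init__.
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.weights_mapper)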

Changed file 1 of 2:

@@ -23,16 +23,7 @@ from transformers.processing_utils import ProcessorMixin
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.kimi_k25_vit import (
KimiK25MultiModalProjector,
@@ -64,7 +55,12 @@ from vllm.transformers_utils.configs import KimiK25Config
from vllm.transformers_utils.processor import cached_get_image_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
from .utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
logger = init_logger(__name__)
@@ -294,6 +290,13 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
supports_encoder_tp_data = True
weights_mapper = WeightsMapper(
orig_to_new_prefix={
"mm_projector.proj.0": "mm_projector.linear_1",
"mm_projector.proj.2": "mm_projector.linear_2",
}
)
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
# Kimi-K2.5 uses video_chunk for all media types
@@ -323,45 +326,39 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self.hidden_size = config.text_config.hidden_size
self.device = current_platform.current_device()
# Build vision tower directly with KimiK25VisionConfig
self.vision_tower = MoonViT3dPretrainedModel(
config.vision_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.vision_tower = self.vision_tower.to(
device=self.device, dtype=model_config.dtype
)
with self._mark_tower_model(vllm_config, "vision_chunk"):
self.vision_tower = MoonViT3dPretrainedModel(
config.vision_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.vision_tower = self.vision_tower.to(
device=self.device, dtype=model_config.dtype
)
self.mm_projector = KimiK25MultiModalProjector(
config=config.vision_config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "mm_projector"),
)
self.mm_projector = self.mm_projector.to(
device=self.device, dtype=model_config.dtype
)
self.mm_projector = KimiK25MultiModalProjector(
config=config.vision_config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "mm_projector"),
)
self.mm_projector = self.mm_projector.to(
device=self.device, dtype=model_config.dtype
)
self.quant_config = quant_config
sub_vllm_config = copy.deepcopy(vllm_config)
sub_vllm_config.model_config.hf_config = (
sub_vllm_config.model_config.hf_config.text_config
)
self.language_model = DeepseekV2Model(
vllm_config=sub_vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.text_config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["DeepseekV2ForCausalLM"],
)
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
self.media_placeholder: int = self.config.media_placeholder_token_id
def _parse_and_validate_media_input(
@@ -421,9 +418,6 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
vision_embeddings = self._process_media_input(media_input)
return vision_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def forward(
self,
input_ids: torch.Tensor,
@@ -444,139 +438,9 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
logits = self.language_model.compute_logits(hidden_states)
return logits
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
config = self.config.text_config
if not getattr(config, "n_routed_experts", None):
return []
return SharedFusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=config.n_routed_experts,
num_redundant_experts=0,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
# mm_projector -> mm_projector mapping
# "mm_projector": "mm_projector",
"mm_projector.proj.0": "mm_projector.linear_1",
"mm_projector.proj.2": "mm_projector.linear_2",
}
stacked_params_mapping = [
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if getattr(config, "kv_lora_rank", None) and getattr(
config, "q_lora_rank", None
):
stacked_params_mapping += [
(".fused_qkv_a_proj", ".q_a_proj", 0),
(".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
]
expert_params_mapping = self.get_expert_mapping()
params_dict = dict(self.named_parameters())
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
use_default_weight_loading = True
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id, **kwargs)
break
else:
for _, (
param_name,
weight_name,
expert_id,
shard_id,
) in enumerate(expert_params_mapping):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
**kwargs,
)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if name.endswith(".bias") and name not in params_dict:
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, **kwargs)
def get_spec_layer_idx_from_weight_name(
config: KimiK25Config, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
# might start with language_model.model.layers
if f"model.layers.{layer_idx + i}." in weight_name:
return layer_idx + i
return None
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.weights_mapper)
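The ~130 removed lines above mostly re-implemented logic the nested text model already provides: with `self.language_model` now built as a full `DeepseekV2ForCausalLM` via `init_vllm_registered_model`, the checkpoint's `language_model.model.*` and `language_model.lm_head` names line up with the module tree directly (the removed `_KEYS_TO_MODIFY_MAPPING` entries for them existed to undo that mismatch), and the fused `qkv_a_proj` stacking, MoE expert mapping, and spec-decode layer skipping are presumably covered by that model's own `load_weights`. Only the projector rename still needs a mapper; a standalone illustration of that remap (hypothetical helper, not the vLLM API):

# Standalone illustration of how the prefix remap in weights_mapper is
# expected to rewrite checkpoint names before loading.
def remap(name: str) -> str:
    prefix_map = {
        "mm_projector.proj.0": "mm_projector.linear_1",
        "mm_projector.proj.2": "mm_projector.linear_2",
    }
    for old, new in prefix_map.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

assert remap("mm_projector.proj.0.weight") == "mm_projector.linear_1.weight"
# Language-model weights already match the module tree, so they pass through
# unchanged and are handled by the nested DeepseekV2ForCausalLM:
assert remap("language_model.lm_head.weight") == "language_model.lm_head.weight"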

Changed file 2 of 2:

@@ -660,13 +660,13 @@ class KimiK25MultiModalProjector(nn.Module):
self.hidden_size,
self.hidden_size,
bias=True,
prefix=maybe_prefix(prefix, "linear_1"),
prefix=f"{prefix}.linear_1",
)
self.linear_2 = ReplicatedLinear(
self.hidden_size,
config.mm_hidden_size,
bias=True,
prefix=maybe_prefix(prefix, "linear_2"),
prefix=f"{prefix}.linear_2",
)
self.act = GELUActivation()
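The projector's sub-layer prefixes switch from `maybe_prefix` to plain f-strings. The two only differ when the parent prefix is empty; a minimal sketch of that difference (the `maybe_prefix` body shown is an assumption about vLLM's helper, not part of this diff):

def maybe_prefix(prefix: str, name: str) -> str:
    # Assumed behavior of the helper: drop the leading dot when there is no parent prefix.
    return name if not prefix else f"{prefix}.{name}"

parent = "model.mm_projector"
assert maybe_prefix(parent, "linear_1") == f"{parent}.linear_1"  # identical with a real prefix
assert maybe_prefix("", "linear_1") == "linear_1"                # helper: bare name
assert f"{''}.linear_1" == ".linear_1"                           # f-string: keeps the dot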