diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py
index 24c89d222..4be79ca95 100644
--- a/vllm/model_executor/models/kimi_k25.py
+++ b/vllm/model_executor/models/kimi_k25.py
@@ -23,16 +23,7 @@ from transformers.processing_utils import ProcessorMixin
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader,
-    maybe_remap_kv_scale_name,
-)
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
 from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
 from vllm.model_executor.models.kimi_k25_vit import (
     KimiK25MultiModalProjector,
@@ -64,7 +55,12 @@ from vllm.transformers_utils.configs import KimiK25Config
 from vllm.transformers_utils.processor import cached_get_image_processor
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
-from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 
 
 logger = init_logger(__name__)
@@ -294,6 +290,13 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
     supports_encoder_tp_data = True
 
+    weights_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "mm_projector.proj.0": "mm_projector.linear_1",
+            "mm_projector.proj.2": "mm_projector.linear_2",
+        }
+    )
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         # Kimi-K2.5 uses video_chunk for all media types
@@ -323,45 +326,39 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         self.hidden_size = config.text_config.hidden_size
         self.device = current_platform.current_device()
         # Build vision tower directly with KimiK25VisionConfig
-        self.vision_tower = MoonViT3dPretrainedModel(
-            config.vision_config,
-            prefix=maybe_prefix(prefix, "vision_tower"),
-        )
-        self.vision_tower = self.vision_tower.to(
-            device=self.device, dtype=model_config.dtype
-        )
+        with self._mark_tower_model(vllm_config, "vision_chunk"):
+            self.vision_tower = MoonViT3dPretrainedModel(
+                config.vision_config,
+                prefix=maybe_prefix(prefix, "vision_tower"),
+            )
+            self.vision_tower = self.vision_tower.to(
+                device=self.device, dtype=model_config.dtype
+            )
 
-        self.mm_projector = KimiK25MultiModalProjector(
-            config=config.vision_config,
-            use_data_parallel=self.use_data_parallel,
-            prefix=maybe_prefix(prefix, "mm_projector"),
-        )
-        self.mm_projector = self.mm_projector.to(
-            device=self.device, dtype=model_config.dtype
-        )
+            self.mm_projector = KimiK25MultiModalProjector(
+                config=config.vision_config,
+                use_data_parallel=self.use_data_parallel,
+                prefix=maybe_prefix(prefix, "mm_projector"),
+            )
+            self.mm_projector = self.mm_projector.to(
+                device=self.device, dtype=model_config.dtype
+            )
 
         self.quant_config = quant_config
         sub_vllm_config = copy.deepcopy(vllm_config)
         sub_vllm_config.model_config.hf_config = (
             sub_vllm_config.model_config.hf_config.text_config
         )
-        self.language_model = DeepseekV2Model(
-            vllm_config=sub_vllm_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
-        if get_pp_group().is_last_rank:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.text_config.hidden_size,
-                prefix=maybe_prefix(prefix, "lm_head"),
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+                architectures=["DeepseekV2ForCausalLM"],
             )
-        else:
-            self.lm_head = PPMissingLayer()
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
         )
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
         self.media_placeholder: int = self.config.media_placeholder_token_id
 
     def _parse_and_validate_media_input(
@@ -421,9 +418,6 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             vision_embeddings = self._process_media_input(media_input)
         return vision_embeddings
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -444,139 +438,9 @@ class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
+        logits = self.language_model.compute_logits(hidden_states)
         return logits
 
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        config = self.config.text_config
-        if not getattr(config, "n_routed_experts", None):
-            return []
-        return SharedFusedMoE.make_expert_params_mapping(
-            self,
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=config.n_routed_experts,
-            num_redundant_experts=0,
-        )
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        config = self.config.text_config
-        _KEYS_TO_MODIFY_MAPPING = {
-            "language_model.lm_head": "lm_head",
-            "language_model.model": "language_model",
-            # mm_projector -> mm_projector mapping
-            # "mm_projector": "mm_projector",
-            "mm_projector.proj.0": "mm_projector.linear_1",
-            "mm_projector.proj.2": "mm_projector.linear_2",
-        }
-        stacked_params_mapping = [
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
-        ]
-        if getattr(config, "kv_lora_rank", None) and getattr(
-            config, "q_lora_rank", None
-        ):
-            stacked_params_mapping += [
-                (".fused_qkv_a_proj", ".q_a_proj", 0),
-                (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
-            ]
-        expert_params_mapping = self.get_expert_mapping()
-
-        params_dict = dict(self.named_parameters())
-
-        for args in weights:
-            name, loaded_weight = args[:2]
-            kwargs = args[2] if len(args) > 2 else {}
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            spec_layer = get_spec_layer_idx_from_weight_name(config, name)
-            if spec_layer is not None:
-                continue  # skip spec decode layers for main model
-
-            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
-                continue
-
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-
-            use_default_weight_loading = False
-            if "vision" in name:
-                if self.vision_tower is not None:
-                    use_default_weight_loading = True
-            else:
-                for param_name, weight_name, shard_id in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    if ("mlp.experts." in name) and name not in params_dict:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id, **kwargs)
-                    break
-                else:
-                    for _, (
-                        param_name,
-                        weight_name,
-                        expert_id,
-                        shard_id,
-                    ) in enumerate(expert_params_mapping):
-                        if weight_name not in name:
-                            continue
-                        name = name.replace(weight_name, param_name)
-
-                        if is_pp_missing_parameter(name, self):
-                            continue
-
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        weight_loader(
-                            param,
-                            loaded_weight,
-                            name,
-                            expert_id=expert_id,
-                            shard_id=shard_id,
-                            **kwargs,
-                        )
-                        break
-                    else:
-                        use_default_weight_loading = True
-
-            if use_default_weight_loading:
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight, **kwargs)
-
-
-def get_spec_layer_idx_from_weight_name(
-    config: KimiK25Config, weight_name: str
-) -> int | None:
-    if hasattr(config, "num_nextn_predict_layers") and (
-        config.num_nextn_predict_layers > 0
-    ):
-        layer_idx = config.num_hidden_layers
-        for i in range(config.num_nextn_predict_layers):
-            # might start with language_model.model.layers
-            if f"model.layers.{layer_idx + i}." in weight_name:
-                return layer_idx + i
-    return None
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights, mapper=self.weights_mapper)
diff --git a/vllm/model_executor/models/kimi_k25_vit.py b/vllm/model_executor/models/kimi_k25_vit.py
index 650ff7d21..470311ecc 100644
--- a/vllm/model_executor/models/kimi_k25_vit.py
+++ b/vllm/model_executor/models/kimi_k25_vit.py
@@ -660,13 +660,13 @@ class KimiK25MultiModalProjector(nn.Module):
             self.hidden_size,
             self.hidden_size,
             bias=True,
-            prefix=maybe_prefix(prefix, "linear_1"),
+            prefix=f"{prefix}.linear_1",
        )
         self.linear_2 = ReplicatedLinear(
             self.hidden_size,
             config.mm_hidden_size,
             bias=True,
-            prefix=maybe_prefix(prefix, "linear_2"),
+            prefix=f"{prefix}.linear_2",
         )
         self.act = GELUActivation()
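Note on the load_weights rewrite above: the hand-rolled loading loop is replaced by vLLM's generic AutoWeightsLoader, and the only Kimi-specific renaming that remains is the class-level WeightsMapper. Below is a minimal standalone sketch of the orig_to_new_prefix remapping this relies on, assuming the checkpoint addresses the projector as an nn.Sequential ("mm_projector.proj.0"/"mm_projector.proj.2") while the vLLM module exposes named layers ("linear_1"/"linear_2"). The remap helper and ORIG_TO_NEW_PREFIX constant are hypothetical illustrations, not vLLM's actual WeightsMapper implementation.

```python
# Illustrative sketch only: mimics the prefix remapping that the class-level
# weights_mapper performs during AutoWeightsLoader.load_weights. Not the real
# vllm.model_executor.models.utils.WeightsMapper code.
ORIG_TO_NEW_PREFIX = {
    "mm_projector.proj.0": "mm_projector.linear_1",
    "mm_projector.proj.2": "mm_projector.linear_2",
}


def remap(name: str) -> str:
    """Rewrite one checkpoint weight name to the model's parameter name."""
    for orig, new in ORIG_TO_NEW_PREFIX.items():
        if name.startswith(orig):
            # Replace the matching prefix, keep the remainder (.weight/.bias).
            return new + name[len(orig):]
    return name  # names without a mapped prefix pass through unchanged


assert remap("mm_projector.proj.0.weight") == "mm_projector.linear_1.weight"
assert remap("mm_projector.proj.2.bias") == "mm_projector.linear_2.bias"
assert remap("vision_tower.blocks.0.attn.qkv.weight") == (
    "vision_tower.blocks.0.attn.qkv.weight"
)
```

With the renaming declared up front, stacked-parameter fusion, expert mapping, and kv-scale remapping are delegated to the wrapped DeepseekV2ForCausalLM's own loader rather than duplicated here, which is what lets the diff drop get_expert_mapping and get_spec_layer_idx_from_weight_name entirely.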