From a8e8d62dd80f53444ae62191fa0bd3901a02c7e7 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Sat, 14 Mar 2026 23:37:52 +0800
Subject: [PATCH] [Misc] Clean up Kimi-audio whisper encoder loading (#36903)

Signed-off-by: Isotr0py
---
 .../model_loader/default_loader.py        |  19 +-
 .../model_loader/weight_utils.py          |  14 +-
 vllm/model_executor/models/kimi_audio.py  | 176 +++++++-----------
 3 files changed, 91 insertions(+), 118 deletions(-)

diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 7064998af..1235792b8 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -52,6 +52,9 @@ class DefaultModelLoader(BaseModelLoader):
         revision: str | None
         """The optional model revision."""
 
+        subfolder: str | None = None
+        """The subfolder inside the model repo."""
+
         prefix: str = ""
         """A prefix to prepend to all weights."""
 
@@ -81,6 +84,7 @@ class DefaultModelLoader(BaseModelLoader):
     def _prepare_weights(
         self,
         model_name_or_path: str,
+        subfolder: str | None,
         revision: str | None,
         fall_back_to_pt: bool,
         allow_patterns_overrides: list[str] | None,
@@ -143,11 +147,15 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config.download_dir,
                 allow_patterns,
                 revision,
+                subfolder=subfolder,
                 ignore_patterns=self.load_config.ignore_patterns,
             )
         else:
             hf_folder = model_name_or_path
 
+        if subfolder is not None:
+            hf_folder = os.path.join(hf_folder, subfolder)
+
         hf_weights_files: list[str] = []
         for pattern in allow_patterns:
             hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
@@ -166,8 +174,9 @@ class DefaultModelLoader(BaseModelLoader):
             download_safetensors_index_file_from_hf(
                 model_name_or_path,
                 index_file,
-                self.load_config.download_dir,
-                revision,
+                cache_dir=self.load_config.download_dir,
+                subfolder=subfolder,
+                revision=revision,
             )
             hf_weights_files = filter_duplicate_safetensors_files(
                 hf_weights_files, hf_folder, index_file
@@ -189,6 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
         extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path,
+            source.subfolder,
             source.revision,
             source.fall_back_to_pt,
             source.allow_patterns_overrides,
@@ -269,8 +279,9 @@ class DefaultModelLoader(BaseModelLoader):
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(
-            model_config.model,
-            model_config.revision,
+            model_name_or_path=model_config.model,
+            subfolder=None,
+            revision=model_config.revision,
             fall_back_to_pt=True,
             allow_patterns_overrides=None,
         )
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index e00a17a15..e7a34ca63 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -472,6 +472,7 @@ def download_weights_from_hf(
     cache_dir: str | None,
     allow_patterns: list[str],
     revision: str | None = None,
+    subfolder: str | None = None,
     ignore_patterns: str | list[str] | None = None,
 ) -> str:
     """Download model weights from Hugging Face Hub.
@@ -484,6 +485,8 @@ def download_weights_from_hf(
             weight files. Files matched by any of the patterns will be
             downloaded.
         revision (Optional[str]): The revision of the model.
+        subfolder (Optional[str]): The subfolder within the model repository
+            to download weights from.
         ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
             filter out the weight files. Files matched by any of the
             patterns will be ignored.
@@ -498,7 +501,11 @@ def download_weights_from_hf(
     # so we only have to call snapshot_download once.
     try:
         fs = HfFileSystem()
-        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+        file_list = fs.ls(
+            os.path.join(model_name_or_path, subfolder or ""),
+            detail=False,
+            revision=revision,
+        )
 
         # If downloading safetensors and an index file exists, use the
         # specific file names from the index to avoid downloading
@@ -510,6 +517,7 @@ def download_weights_from_hf(
                 filename=SAFE_WEIGHTS_INDEX_NAME,
                 cache_dir=cache_dir,
                 revision=revision,
+                subfolder=subfolder,
             )
             with open(index_path) as f:
                 weight_map = json.load(f)["weight_map"]
@@ -570,6 +578,7 @@ def download_safetensors_index_file_from_hf(
     model_name_or_path: str,
     index_file: str,
     cache_dir: str | None,
+    subfolder: str | None = None,
     revision: str | None = None,
 ) -> None:
     """Download hf safetensors index file from Hugging Face Hub.
@@ -579,6 +588,8 @@ def download_safetensors_index_file_from_hf(
         index_file (str): The safetensors index file name
         cache_dir (Optional[str]): The cache directory to store the model
             weights. If None, will use HF defaults.
+        subfolder (Optional[str]): The subfolder within the model repository
+            to download weights from.
         revision (Optional[str]): The revision of the model.
     """
     # Use file lock to prevent multiple processes from
@@ -591,6 +602,7 @@ def download_safetensors_index_file_from_hf(
             filename=index_file,
             cache_dir=cache_dir,
             revision=revision,
+            subfolder=subfolder,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
     # If file not found on remote or locally, we should not fail since
diff --git a/vllm/model_executor/models/kimi_audio.py b/vllm/model_executor/models/kimi_audio.py
index 6f15a4388..36d22d867 100644
--- a/vllm/model_executor/models/kimi_audio.py
+++ b/vllm/model_executor/models/kimi_audio.py
@@ -3,15 +3,12 @@
 
 """Inference-only Kimi-Audio model compatible with HuggingFace weights."""
 
-import os
 from collections.abc import Iterable, Mapping, Sequence
 from typing import Any, ClassVar, Literal
 
 import numpy as np
 import torch
 import torch.nn as nn
-from huggingface_hub import snapshot_download
-from safetensors import safe_open
 from transformers import BatchFeature
 from transformers import WhisperConfig as HFWhisperConfig
 
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader,
-)
+from vllm.model_executor.model_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     SupportsPP,
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
 KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
 
 
-def _get_whisper_local_path(repo_id: str):
-    if os.path.exists(repo_id):
-        repo_local_path = repo_id
-    else:
-        repo_local_path = snapshot_download(repo_id, local_files_only=True)
-
-    return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
-
-
 def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
     """Compute output lengths after Whisper feature extraction.
 
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
     # packed_modules_mapping for Q/K/V fusion during weight loading
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "kv_proj": ["k_proj", "v_proj"],
     }
 
     def __init__(
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
         model_path = vllm_config.model_config.model
 
         # Load WhisperConfig from the subfolder
-        whisper_dir = _get_whisper_local_path(model_path)
-        whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
-
-        # Temporarily replace hf_config for WhisperEncoder.__init__()
-        original_config = vllm_config.model_config.hf_config
-        vllm_config.model_config.hf_config = whisper_config
-
-        super().__init__(
-            vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
+        whisper_config = HFWhisperConfig.from_pretrained(
+            model_path,
+            subfolder=KIMIA_WHISPER_SUBFOLDER,
         )
 
-        # Restore original config
-        vllm_config.model_config.hf_config = original_config
+        super().__init__(
+            vllm_config=vllm_config.with_hf_config(whisper_config),
+            prefix=prefix,
+            init_in_fp32=init_in_fp32,
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 # -----------------------------------------------------------------------------
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
+            # audio tower
+            "model.encoder.": "audio_tower.",
             # Audio projector (VQ-Adaptor)
             "model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
             "model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
             "model.embed_tokens.": "language_model.model.embed_tokens.",
             "model.norm.": "language_model.model.norm.",
             "lm_head.": "language_model.lm_head.",
-        }
+        },
+        orig_to_new_substr={
+            ".fc1.": ".mlp.fc1.",
+            ".fc2.": ".mlp.fc2.",
+        },
     )
 
     # Audio placeholder token sequence
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
         self.multimodal_config = vllm_config.model_config.multimodal_config
         self.model_path = vllm_config.model_config.model
 
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                subfolder="whisper-large-v3",
+                revision=None,
+            )
+        ]
+
         self.audio_tower = KimiAudioWhisperEncoder(
             vllm_config=vllm_config,
             prefix=maybe_prefix(prefix, "audio_tower"),
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
         """Load weights, skipping MIMO layers (TTS-only) for ASR."""
         # Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
         skipped_patterns = [
+            # Audio tower
+            "model.",
+            # MIMO/TTS
             "mimo_layers.",
             "mimo_output.",
             "mimo_norm.",
-            "audio_decoder.",
-        ]
-
-        # Filter weights
-        filtered_weights = [
-            (name, param)
-            for name, param in weights
-            if not any(pattern in name for pattern in skipped_patterns)
-        ]
-
-        # Separate main weights (non-Whisper) from Whisper weights
-        main_weights = [
-            (name, param)
-            for name, param in filtered_weights
-            if not name.startswith("audio_tower.")
         ]
 
         # Load main model weights (LLM + projector) with mapper
-        loader = AutoWeightsLoader(self)
-        loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
-
-        # Load Whisper encoder weights from subfolder
-        whisper_dir = _get_whisper_local_path(self.model_path)
-        whisper_path = os.path.join(whisper_dir, "model.safetensors")
-        if os.path.exists(whisper_path):
-            whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
-            loaded.update(whisper_loaded)
-
+        loader = AutoWeightsLoader(self, skip_prefixes=skipped_patterns)
+        loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
         return loaded
 
-    def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
-        """Load Whisper encoder weights from safetensors file with transformations."""
-        if not os.path.exists(whisper_path):
-            return set()
-
-        # Step 1: Load raw weights from safetensors file
-        whisper_weights = []
-        with safe_open(whisper_path, framework="pt") as f:
-            for key in f.keys():  # noqa: SIM118
-                if key.startswith("model.encoder.") and "embed_positions" not in key:
-                    new_key = key.replace("model.encoder.", "")
-                    whisper_weights.append((new_key, f.get_tensor(key)))
-
-        # Step 2: Apply fc → mlp mapping using WeightsMapper
-        fc_mapper = WeightsMapper(
-            orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
-        )
-        whisper_mapped = list(fc_mapper.apply(whisper_weights))
-
-        # Step 3: Apply Q/K/V fusion manually
-        stacked_params_mapping = [
-            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
-            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
-            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
-        ]
-
-        params_dict = dict(self.audio_tower.named_parameters())
-        whisper_loaded: set[str] = set()
-
-        for name, loaded_weight in whisper_mapped:
-            fused = False
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                fused_name = name.replace(weight_name, param_name)
-                if fused_name not in params_dict:
-                    continue
-
-                param = params_dict[fused_name]
-                param.weight_loader(param, loaded_weight, shard_id)
-                whisper_loaded.add(f"audio_tower.{fused_name}")
-                fused = True
-                break
-
-            if not fused:
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-                whisper_loaded.add(f"audio_tower.{name}")
-
-        # Add embed_positions which is initialized randomly
-        whisper_loaded.add("audio_tower.embed_positions.weight")
-
-        return whisper_loaded
-
     @classmethod
     def get_speech_to_text_config(
         cls, model_config: ModelConfig, task_type: str