[Misc] Clean up Kimi-audio whisper encoder loading (#36903)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-03-14 23:37:52 +08:00
committed by GitHub
parent e42b49bd69
commit a8e8d62dd8
3 changed files with 91 additions and 118 deletions

View File

@@ -52,6 +52,9 @@ class DefaultModelLoader(BaseModelLoader):
revision: str | None revision: str | None
"""The optional model revision.""" """The optional model revision."""
subfolder: str | None = None
"""The subfolder inside the model repo."""
prefix: str = "" prefix: str = ""
"""A prefix to prepend to all weights.""" """A prefix to prepend to all weights."""
@@ -81,6 +84,7 @@ class DefaultModelLoader(BaseModelLoader):
def _prepare_weights( def _prepare_weights(
self, self,
model_name_or_path: str, model_name_or_path: str,
subfolder: str | None,
revision: str | None, revision: str | None,
fall_back_to_pt: bool, fall_back_to_pt: bool,
allow_patterns_overrides: list[str] | None, allow_patterns_overrides: list[str] | None,
@@ -143,11 +147,15 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.download_dir, self.load_config.download_dir,
allow_patterns, allow_patterns,
revision, revision,
subfolder=subfolder,
ignore_patterns=self.load_config.ignore_patterns, ignore_patterns=self.load_config.ignore_patterns,
) )
else: else:
hf_folder = model_name_or_path hf_folder = model_name_or_path
if subfolder is not None:
hf_folder = os.path.join(hf_folder, subfolder)
hf_weights_files: list[str] = [] hf_weights_files: list[str] = []
for pattern in allow_patterns: for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
@@ -166,8 +174,9 @@ class DefaultModelLoader(BaseModelLoader):
download_safetensors_index_file_from_hf( download_safetensors_index_file_from_hf(
model_name_or_path, model_name_or_path,
index_file, index_file,
self.load_config.download_dir, cache_dir=self.load_config.download_dir,
revision, subfolder=subfolder,
revision=revision,
) )
hf_weights_files = filter_duplicate_safetensors_files( hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file hf_weights_files, hf_folder, index_file
@@ -189,6 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
extra_config = self.load_config.model_loader_extra_config extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.model_or_path,
source.subfolder,
source.revision, source.revision,
source.fall_back_to_pt, source.fall_back_to_pt,
source.allow_patterns_overrides, source.allow_patterns_overrides,
@@ -269,8 +279,9 @@ class DefaultModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights( self._prepare_weights(
model_config.model, model_name_or_path=model_config.model,
model_config.revision, subfolder=None,
revision=model_config.revision,
fall_back_to_pt=True, fall_back_to_pt=True,
allow_patterns_overrides=None, allow_patterns_overrides=None,
) )

View File

@@ -472,6 +472,7 @@ def download_weights_from_hf(
cache_dir: str | None, cache_dir: str | None,
allow_patterns: list[str], allow_patterns: list[str],
revision: str | None = None, revision: str | None = None,
subfolder: str | None = None,
ignore_patterns: str | list[str] | None = None, ignore_patterns: str | list[str] | None = None,
) -> str: ) -> str:
"""Download model weights from Hugging Face Hub. """Download model weights from Hugging Face Hub.
@@ -484,6 +485,8 @@ def download_weights_from_hf(
weight files. Files matched by any of the patterns will be weight files. Files matched by any of the patterns will be
downloaded. downloaded.
revision (Optional[str]): The revision of the model. revision (Optional[str]): The revision of the model.
subfolder (Optional[str]): The subfolder within the model repository
to download weights from.
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
filter out the weight files. Files matched by any of the patterns filter out the weight files. Files matched by any of the patterns
will be ignored. will be ignored.
@@ -498,7 +501,11 @@ def download_weights_from_hf(
# so we only have to call snapshot_download once. # so we only have to call snapshot_download once.
try: try:
fs = HfFileSystem() fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision) file_list = fs.ls(
os.path.join(model_name_or_path, subfolder or ""),
detail=False,
revision=revision,
)
# If downloading safetensors and an index file exists, use the # If downloading safetensors and an index file exists, use the
# specific file names from the index to avoid downloading # specific file names from the index to avoid downloading
@@ -510,6 +517,7 @@ def download_weights_from_hf(
filename=SAFE_WEIGHTS_INDEX_NAME, filename=SAFE_WEIGHTS_INDEX_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
revision=revision, revision=revision,
subfolder=subfolder,
) )
with open(index_path) as f: with open(index_path) as f:
weight_map = json.load(f)["weight_map"] weight_map = json.load(f)["weight_map"]
@@ -570,6 +578,7 @@ def download_safetensors_index_file_from_hf(
model_name_or_path: str, model_name_or_path: str,
index_file: str, index_file: str,
cache_dir: str | None, cache_dir: str | None,
subfolder: str | None = None,
revision: str | None = None, revision: str | None = None,
) -> None: ) -> None:
"""Download hf safetensors index file from Hugging Face Hub. """Download hf safetensors index file from Hugging Face Hub.
@@ -579,6 +588,8 @@ def download_safetensors_index_file_from_hf(
index_file (str): The safetensors index file name index_file (str): The safetensors index file name
cache_dir (Optional[str]): The cache directory to store the model cache_dir (Optional[str]): The cache directory to store the model
weights. If None, will use HF defaults. weights. If None, will use HF defaults.
subfolder (Optional[str]): The subfolder within the model repository
to download weights from.
revision (Optional[str]): The revision of the model. revision (Optional[str]): The revision of the model.
""" """
# Use file lock to prevent multiple processes from # Use file lock to prevent multiple processes from
@@ -591,6 +602,7 @@ def download_safetensors_index_file_from_hf(
filename=index_file, filename=index_file,
cache_dir=cache_dir, cache_dir=cache_dir,
revision=revision, revision=revision,
subfolder=subfolder,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
) )
# If file not found on remote or locally, we should not fail since # If file not found on remote or locally, we should not fail since

View File

@@ -3,15 +3,12 @@
"""Inference-only Kimi-Audio model compatible with HuggingFace weights.""" """Inference-only Kimi-Audio model compatible with HuggingFace weights."""
import os
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Any, ClassVar, Literal from typing import Any, ClassVar, Literal
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import BatchFeature from transformers import BatchFeature
from transformers import WhisperConfig as HFWhisperConfig from transformers import WhisperConfig as HFWhisperConfig
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader import DefaultModelLoader
default_weight_loader, from vllm.model_executor.model_loader.weight_utils import default_weight_loader
)
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3" KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
def _get_whisper_local_path(repo_id: str):
    """Return the local directory containing the bundled Whisper weights.

    If *repo_id* is already a path on disk it is used directly; otherwise the
    previously-downloaded HF snapshot is located (no network access, since
    ``local_files_only=True``). The Whisper subfolder is appended either way.
    """
    base_dir = (
        repo_id
        if os.path.exists(repo_id)
        else snapshot_download(repo_id, local_files_only=True)
    )
    return os.path.join(base_dir, KIMIA_WHISPER_SUBFOLDER)
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor: def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after Whisper feature extraction. """Compute output lengths after Whisper feature extraction.
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
# packed_modules_mapping for Q/K/V fusion during weight loading # packed_modules_mapping for Q/K/V fusion during weight loading
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"kv_proj": ["k_proj", "v_proj"],
} }
def __init__( def __init__(
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
model_path = vllm_config.model_config.model model_path = vllm_config.model_config.model
# Load WhisperConfig from the subfolder # Load WhisperConfig from the subfolder
whisper_dir = _get_whisper_local_path(model_path) whisper_config = HFWhisperConfig.from_pretrained(
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir) model_path,
subfolder=KIMIA_WHISPER_SUBFOLDER,
# Temporarily replace hf_config for WhisperEncoder.__init__()
original_config = vllm_config.model_config.hf_config
vllm_config.model_config.hf_config = whisper_config
super().__init__(
vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
) )
# Restore original config super().__init__(
vllm_config.model_config.hf_config = original_config vllm_config=vllm_config.with_hf_config(whisper_config),
prefix=prefix,
init_in_fp32=init_in_fp32,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load encoder weights, fusing separate q/k/v projections into qkv_proj.

Args:
    weights: Iterable of (parameter name, tensor) pairs from the checkpoint.

Returns:
    The set of parameter names that were actually loaded.
"""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
# Try the fused-projection path first: if the checkpoint name contains
# q_proj/k_proj/v_proj, remap it onto the packed qkv_proj parameter
# and load it as the corresponding shard.
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# No stacked mapping matched: load the weight one-to-one, using the
# parameter's own weight_loader when present.
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# NOTE: reached both after a fused-shard load (break) and a plain load;
# `name` is the post-remap (vLLM-side) parameter name.
loaded_params.add(name)
return loaded_params
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={ orig_to_new_prefix={
# audio tower
"model.encoder.": "audio_tower.",
# Audio projector (VQ-Adaptor) # Audio projector (VQ-Adaptor)
"model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.", "model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
"model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.", "model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
"model.embed_tokens.": "language_model.model.embed_tokens.", "model.embed_tokens.": "language_model.model.embed_tokens.",
"model.norm.": "language_model.model.norm.", "model.norm.": "language_model.model.norm.",
"lm_head.": "language_model.lm_head.", "lm_head.": "language_model.lm_head.",
} },
orig_to_new_substr={
".fc1.": ".mlp.fc1.",
".fc2.": ".mlp.fc2.",
},
) )
# Audio placeholder token sequence # Audio placeholder token sequence
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
self.multimodal_config = vllm_config.model_config.multimodal_config self.multimodal_config = vllm_config.model_config.multimodal_config
self.model_path = vllm_config.model_config.model self.model_path = vllm_config.model_config.model
self.secondary_weights = [
DefaultModelLoader.Source(
model_or_path=vllm_config.model_config.model,
subfolder="whisper-large-v3",
revision=None,
)
]
self.audio_tower = KimiAudioWhisperEncoder( self.audio_tower = KimiAudioWhisperEncoder(
vllm_config=vllm_config, vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "audio_tower"), prefix=maybe_prefix(prefix, "audio_tower"),
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
"""Load weights, skipping MIMO layers (TTS-only) for ASR.""" """Load weights, skipping MIMO layers (TTS-only) for ASR."""
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text) # Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
skipped_patterns = [ skipped_patterns = [
# Audio tower
"model.",
# MIMO/TTS
"mimo_layers.", "mimo_layers.",
"mimo_output.", "mimo_output.",
"mimo_norm.", "mimo_norm.",
"audio_decoder.",
]
# Filter weights
filtered_weights = [
(name, param)
for name, param in weights
if not any(pattern in name for pattern in skipped_patterns)
]
# Separate main weights (non-Whisper) from Whisper weights
main_weights = [
(name, param)
for name, param in filtered_weights
if not name.startswith("audio_tower.")
] ]
# Load main model weights (LLM + projector) with mapper # Load main model weights (LLM + projector) with mapper
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self, skip_prefixes=skipped_patterns)
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper) loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
# Load Whisper encoder weights from subfolder
whisper_dir = _get_whisper_local_path(self.model_path)
whisper_path = os.path.join(whisper_dir, "model.safetensors")
if os.path.exists(whisper_path):
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
loaded.update(whisper_loaded)
return loaded return loaded
def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
"""Load Whisper encoder weights from safetensors file with transformations.

Args:
    whisper_path: Path to the Whisper ``model.safetensors`` file.

Returns:
    The set of ``audio_tower.``-prefixed parameter names that were loaded
    (best-effort: returns an empty set if the file does not exist).
"""
if not os.path.exists(whisper_path):
return set()
# Step 1: Load raw weights from safetensors file
# Keep only encoder tensors, strip the "model.encoder." prefix, and drop
# embed_positions (it is re-initialized locally, see the end of this method).
whisper_weights = []
with safe_open(whisper_path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
if key.startswith("model.encoder.") and "embed_positions" not in key:
new_key = key.replace("model.encoder.", "")
whisper_weights.append((new_key, f.get_tensor(key)))
# Step 2: Apply fc → mlp mapping using WeightsMapper
# HF Whisper names the MLP linears fc1/fc2 directly on the layer; the vLLM
# module nests them under ".mlp.".
fc_mapper = WeightsMapper(
orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
)
whisper_mapped = list(fc_mapper.apply(whisper_weights))
# Step 3: Apply Q/K/V fusion manually
stacked_params_mapping = [
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
params_dict = dict(self.audio_tower.named_parameters())
whisper_loaded: set[str] = set()
for name, loaded_weight in whisper_mapped:
fused = False
# Fused path: route q/k/v tensors into the packed qkv_proj shard.
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
fused_name = name.replace(weight_name, param_name)
if fused_name not in params_dict:
continue
param = params_dict[fused_name]
param.weight_loader(param, loaded_weight, shard_id)
whisper_loaded.add(f"audio_tower.{fused_name}")
fused = True
break
if not fused:
# Fallback path: one-to-one load; silently skip names (and GPTQ-style
# extra biases) that have no matching parameter on the audio tower.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
whisper_loaded.add(f"audio_tower.{name}")
# Add embed_positions which is initialized randomly
# (marked loaded so the caller's bookkeeping does not report it missing).
whisper_loaded.add("audio_tower.embed_positions.weight")
return whisper_loaded
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str cls, model_config: ModelConfig, task_type: str