[Misc] Clean up Kimi-audio whisper encoder loading (#36903)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py
2026-03-14 23:37:52 +08:00
committed by GitHub
parent e42b49bd69
commit a8e8d62dd8
3 changed files with 91 additions and 118 deletions
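In short: instead of resolving the whisper-large-v3 snapshot on disk and temporarily swapping hf_config around WhisperEncoder.__init__(), the encoder config is now read straight from the checkpoint subfolder and handed over via with_hf_config(), and the subfolder weights are registered as a secondary source for the default loader. A minimal sketch of the config half, assuming only that the checkpoint keeps its Whisper encoder under whisper-large-v3/ (the repo id below is illustrative):

from transformers import WhisperConfig

# from_pretrained resolves both local directories and Hub repo ids, so no
# explicit os.path.exists()/snapshot_download() branching is needed; the
# encoder config lives in the checkpoint's "whisper-large-v3" subfolder.
whisper_config = WhisperConfig.from_pretrained(
    "moonshotai/Kimi-Audio-7B-Instruct",  # assumed repo id, for illustration
    subfolder="whisper-large-v3",
)
print(whisper_config.d_model)  # 1280 for whisper-large-v3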


@@ -3,15 +3,12 @@
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
import os
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, ClassVar, Literal
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import BatchFeature
from transformers import WhisperConfig as HFWhisperConfig
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader,
-)
+from vllm.model_executor.model_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     SupportsPP,
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
 KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
 
 
-def _get_whisper_local_path(repo_id: str):
-    if os.path.exists(repo_id):
-        repo_local_path = repo_id
-    else:
-        repo_local_path = snapshot_download(repo_id, local_files_only=True)
-    return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
-
-
 def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
     """Compute output lengths after Whisper feature extraction.
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
     # packed_modules_mapping for Q/K/V fusion during weight loading
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "kv_proj": ["k_proj", "v_proj"],
     }
 
     def __init__(
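A note on packed_modules_mapping: it tells vLLM (quantization configs in particular) which original checkpoint modules are packed into one fused runtime module; the "kv_proj" entry is dropped, presumably because this encoder only ever materializes the fused qkv_proj. A toy lookup showing what the mapping encodes:

# Which fused runtime module holds the checkpoint's "k_proj"?
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
fused = next(n for n, packed in packed_modules_mapping.items() if "k_proj" in packed)
print(fused)  # qkv_proj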
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
         model_path = vllm_config.model_config.model
 
         # Load WhisperConfig from the subfolder
-        whisper_dir = _get_whisper_local_path(model_path)
-        whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
-        # Temporarily replace hf_config for WhisperEncoder.__init__()
-        original_config = vllm_config.model_config.hf_config
-        vllm_config.model_config.hf_config = whisper_config
-        super().__init__(
-            vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
+        whisper_config = HFWhisperConfig.from_pretrained(
+            model_path,
+            subfolder=KIMIA_WHISPER_SUBFOLDER,
         )
-        # Restore original config
-        vllm_config.model_config.hf_config = original_config
+        super().__init__(
+            vllm_config=vllm_config.with_hf_config(whisper_config),
+            prefix=prefix,
+            init_in_fp32=init_in_fp32,
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 # -----------------------------------------------------------------------------
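The new load_weights() above is vLLM's standard stacked-parameter loop: q/k/v shard names are rewritten onto the fused parameter and dispatched with a shard id, and the for/else branch handles everything that is not stacked. A self-contained toy of the routing step (no vLLM dependencies; the helper name is invented):

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def route(name: str) -> tuple[str, str | None]:
    # Mirrors the inner for/else: first match wins, otherwise fall through.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # plain weight, loaded with the default loader

assert route("layers.0.self_attn.k_proj.weight") == (
    "layers.0.self_attn.qkv_proj.weight",
    "k",
)
assert route("conv1.weight") == ("conv1.weight", None)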
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
+            # audio tower
+            "model.encoder.": "audio_tower.",
             # Audio projector (VQ-Adaptor)
             "model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
             "model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
"model.embed_tokens.": "language_model.model.embed_tokens.",
"model.norm.": "language_model.model.norm.",
"lm_head.": "language_model.lm_head.",
}
},
orig_to_new_substr={
".fc1.": ".mlp.fc1.",
".fc2.": ".mlp.fc2.",
},
)
# Audio placeholder token sequence
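Combined, the two tables rewrite whisper checkpoint keys into the vLLM module tree in one pass. A quick check with WeightsMapper.apply(), the same helper the removed code below used (import path assumed to be the usual vllm.model_executor.models.utils; the example key is invented):

from vllm.model_executor.models.utils import WeightsMapper

mapper = WeightsMapper(
    orig_to_new_prefix={"model.encoder.": "audio_tower."},
    orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."},
)
weights = [("model.encoder.layers.0.fc1.weight", None)]
print(list(mapper.apply(weights)))
# [('audio_tower.layers.0.mlp.fc1.weight', None)]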
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
         self.multimodal_config = vllm_config.model_config.multimodal_config
         self.model_path = vllm_config.model_config.model
+        self.secondary_weights = [
+            DefaultModelLoader.Source(
+                model_or_path=vllm_config.model_config.model,
+                subfolder="whisper-large-v3",
+                revision=None,
+            )
+        ]
+
         self.audio_tower = KimiAudioWhisperEncoder(
             vllm_config=vllm_config,
             prefix=maybe_prefix(prefix, "audio_tower"),
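The secondary_weights list above follows the pattern Ultravox already uses in vLLM: DefaultModelLoader looks for this attribute on the model and streams each extra Source after the primary checkpoint, which is how load_weights() later sees the whisper-large-v3 tensors without a hand-rolled safetensors pass. Roughly (repo id illustrative):

from vllm.model_executor.model_loader import DefaultModelLoader

# Each Source is resolved like a normal checkpoint; `subfolder` scopes the
# weight files to the whisper-large-v3 directory inside the same repo.
extra = DefaultModelLoader.Source(
    model_or_path="moonshotai/Kimi-Audio-7B-Instruct",  # assumed repo id
    subfolder="whisper-large-v3",
    revision=None,
)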
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
skipped_patterns = [
# Audio tower
"model.",
# MIMO/TTS
"mimo_layers.",
"mimo_output.",
"mimo_norm.",
"audio_decoder.",
]
# Filter weights
filtered_weights = [
(name, param)
for name, param in weights
if not any(pattern in name for pattern in skipped_patterns)
]
# Separate main weights (non-Whisper) from Whisper weights
main_weights = [
(name, param)
for name, param in filtered_weights
if not name.startswith("audio_tower.")
]
# Load main model weights (LLM + projector) with mapper
loader = AutoWeightsLoader(self)
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
# Load Whisper encoder weights from subfolder
whisper_dir = _get_whisper_local_path(self.model_path)
whisper_path = os.path.join(whisper_dir, "model.safetensors")
if os.path.exists(whisper_path):
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
loaded.update(whisper_loaded)
loader = AutoWeightsLoader(self, skip_prefixes=skipped_patterns)
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loaded
def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
"""Load Whisper encoder weights from safetensors file with transformations."""
if not os.path.exists(whisper_path):
return set()
# Step 1: Load raw weights from safetensors file
whisper_weights = []
with safe_open(whisper_path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
if key.startswith("model.encoder.") and "embed_positions" not in key:
new_key = key.replace("model.encoder.", "")
whisper_weights.append((new_key, f.get_tensor(key)))
# Step 2: Apply fc → mlp mapping using WeightsMapper
fc_mapper = WeightsMapper(
orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
)
whisper_mapped = list(fc_mapper.apply(whisper_weights))
# Step 3: Apply Q/K/V fusion manually
stacked_params_mapping = [
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
params_dict = dict(self.audio_tower.named_parameters())
whisper_loaded: set[str] = set()
for name, loaded_weight in whisper_mapped:
fused = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
fused_name = name.replace(weight_name, param_name)
if fused_name not in params_dict:
continue
param = params_dict[fused_name]
param.weight_loader(param, loaded_weight, shard_id)
whisper_loaded.add(f"audio_tower.{fused_name}")
fused = True
break
if not fused:
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
whisper_loaded.add(f"audio_tower.{name}")
# Add embed_positions which is initialized randomly
whisper_loaded.add("audio_tower.embed_positions.weight")
return whisper_loaded
 
     @classmethod
     def get_speech_to_text_config(
         cls, model_config: ModelConfig, task_type: str
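One subtlety worth spelling out: AutoWeightsLoader applies hf_to_vllm_mapper before its skip logic, so the "model." prefix only drops keys the mapper left untouched (for example the whisper decoder weights that ship alongside the encoder in the subfolder checkpoint), not the LLM weights, which have already been renamed to "language_model.*". A toy equivalent of the post-mapping filter (names invented to illustrate the three cases):

skip_prefixes = ("model.", "mimo_layers.", "mimo_output.", "mimo_norm.", "audio_decoder.")
mapped_names = [
    "model.decoder.layers.0.self_attn.q_proj.weight",  # unmapped whisper decoder: dropped
    "audio_tower.layers.0.self_attn.q_proj.weight",    # whisper encoder: kept
    "language_model.model.embed_tokens.weight",        # LLM, already renamed: kept
]
print([n for n in mapped_names if not n.startswith(skip_prefixes)])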