[Misc] Clean up Kimi-audio whisper encoder loading (#36903)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -3,15 +3,12 @@
|
||||
|
||||
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
from transformers import BatchFeature
|
||||
from transformers import WhisperConfig as HFWhisperConfig
|
||||
|
||||
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.model_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||
|
||||
|
||||
def _get_whisper_local_path(repo_id: str):
|
||||
if os.path.exists(repo_id):
|
||||
repo_local_path = repo_id
|
||||
else:
|
||||
repo_local_path = snapshot_download(repo_id, local_files_only=True)
|
||||
|
||||
return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after Whisper feature extraction.
|
||||
|
||||
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
# packed_modules_mapping for Q/K/V fusion during weight loading
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"kv_proj": ["k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
model_path = vllm_config.model_config.model
|
||||
|
||||
# Load WhisperConfig from the subfolder
|
||||
whisper_dir = _get_whisper_local_path(model_path)
|
||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
|
||||
|
||||
# Temporarily replace hf_config for WhisperEncoder.__init__()
|
||||
original_config = vllm_config.model_config.hf_config
|
||||
vllm_config.model_config.hf_config = whisper_config
|
||||
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
|
||||
whisper_config = HFWhisperConfig.from_pretrained(
|
||||
model_path,
|
||||
subfolder=KIMIA_WHISPER_SUBFOLDER,
|
||||
)
|
||||
|
||||
# Restore original config
|
||||
vllm_config.model_config.hf_config = original_config
|
||||
super().__init__(
|
||||
vllm_config=vllm_config.with_hf_config(whisper_config),
|
||||
prefix=prefix,
|
||||
init_in_fp32=init_in_fp32,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# audio tower
|
||||
"model.encoder.": "audio_tower.",
|
||||
# Audio projector (VQ-Adaptor)
|
||||
"model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
|
||||
"model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
|
||||
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
|
||||
"model.embed_tokens.": "language_model.model.embed_tokens.",
|
||||
"model.norm.": "language_model.model.norm.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
}
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".fc1.": ".mlp.fc1.",
|
||||
".fc2.": ".mlp.fc2.",
|
||||
},
|
||||
)
|
||||
|
||||
# Audio placeholder token sequence
|
||||
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
|
||||
self.multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.model_path = vllm_config.model_config.model
|
||||
|
||||
self.secondary_weights = [
|
||||
DefaultModelLoader.Source(
|
||||
model_or_path=vllm_config.model_config.model,
|
||||
subfolder="whisper-large-v3",
|
||||
revision=None,
|
||||
)
|
||||
]
|
||||
|
||||
self.audio_tower = KimiAudioWhisperEncoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
|
||||
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
|
||||
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
|
||||
skipped_patterns = [
|
||||
# Audio tower
|
||||
"model.",
|
||||
# MIMO/TTS
|
||||
"mimo_layers.",
|
||||
"mimo_output.",
|
||||
"mimo_norm.",
|
||||
"audio_decoder.",
|
||||
]
|
||||
|
||||
# Filter weights
|
||||
filtered_weights = [
|
||||
(name, param)
|
||||
for name, param in weights
|
||||
if not any(pattern in name for pattern in skipped_patterns)
|
||||
]
|
||||
|
||||
# Separate main weights (non-Whisper) from Whisper weights
|
||||
main_weights = [
|
||||
(name, param)
|
||||
for name, param in filtered_weights
|
||||
if not name.startswith("audio_tower.")
|
||||
]
|
||||
|
||||
# Load main model weights (LLM + projector) with mapper
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
# Load Whisper encoder weights from subfolder
|
||||
whisper_dir = _get_whisper_local_path(self.model_path)
|
||||
whisper_path = os.path.join(whisper_dir, "model.safetensors")
|
||||
if os.path.exists(whisper_path):
|
||||
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
|
||||
loaded.update(whisper_loaded)
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skipped_patterns)
|
||||
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
return loaded
|
||||
|
||||
def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
|
||||
"""Load Whisper encoder weights from safetensors file with transformations."""
|
||||
if not os.path.exists(whisper_path):
|
||||
return set()
|
||||
|
||||
# Step 1: Load raw weights from safetensors file
|
||||
whisper_weights = []
|
||||
with safe_open(whisper_path, framework="pt") as f:
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
if key.startswith("model.encoder.") and "embed_positions" not in key:
|
||||
new_key = key.replace("model.encoder.", "")
|
||||
whisper_weights.append((new_key, f.get_tensor(key)))
|
||||
|
||||
# Step 2: Apply fc → mlp mapping using WeightsMapper
|
||||
fc_mapper = WeightsMapper(
|
||||
orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
|
||||
)
|
||||
whisper_mapped = list(fc_mapper.apply(whisper_weights))
|
||||
|
||||
# Step 3: Apply Q/K/V fusion manually
|
||||
stacked_params_mapping = [
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.audio_tower.named_parameters())
|
||||
whisper_loaded: set[str] = set()
|
||||
|
||||
for name, loaded_weight in whisper_mapped:
|
||||
fused = False
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
fused_name = name.replace(weight_name, param_name)
|
||||
if fused_name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[fused_name]
|
||||
param.weight_loader(param, loaded_weight, shard_id)
|
||||
whisper_loaded.add(f"audio_tower.{fused_name}")
|
||||
fused = True
|
||||
break
|
||||
|
||||
if not fused:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
whisper_loaded.add(f"audio_tower.{name}")
|
||||
|
||||
# Add embed_positions which is initialized randomly
|
||||
whisper_loaded.add("audio_tower.embed_positions.weight")
|
||||
|
||||
return whisper_loaded
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
|
||||
Reference in New Issue
Block a user