[Misc] Clean up Kimi-audio whisper encoder loading (#36903)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -52,6 +52,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
revision: str | None
|
||||
"""The optional model revision."""
|
||||
|
||||
subfolder: str | None = None
|
||||
"""The subfolder inside the model repo."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prefix to prepend to all weights."""
|
||||
|
||||
@@ -81,6 +84,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
def _prepare_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
subfolder: str | None,
|
||||
revision: str | None,
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: list[str] | None,
|
||||
@@ -143,11 +147,15 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
subfolder=subfolder,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
if subfolder is not None:
|
||||
hf_folder = os.path.join(hf_folder, subfolder)
|
||||
|
||||
hf_weights_files: list[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
@@ -166,8 +174,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path,
|
||||
index_file,
|
||||
self.load_config.download_dir,
|
||||
revision,
|
||||
cache_dir=self.load_config.download_dir,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder, index_file
|
||||
@@ -189,6 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
extra_config = self.load_config.model_loader_extra_config
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path,
|
||||
source.subfolder,
|
||||
source.revision,
|
||||
source.fall_back_to_pt,
|
||||
source.allow_patterns_overrides,
|
||||
@@ -269,8 +279,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
model_name_or_path=model_config.model,
|
||||
subfolder=None,
|
||||
revision=model_config.revision,
|
||||
fall_back_to_pt=True,
|
||||
allow_patterns_overrides=None,
|
||||
)
|
||||
|
||||
@@ -472,6 +472,7 @@ def download_weights_from_hf(
|
||||
cache_dir: str | None,
|
||||
allow_patterns: list[str],
|
||||
revision: str | None = None,
|
||||
subfolder: str | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
@@ -484,6 +485,8 @@ def download_weights_from_hf(
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
subfolder (Optional[str]): The subfolder within the model repository
|
||||
to download weights from.
|
||||
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
||||
filter out the weight files. Files matched by any of the patterns
|
||||
will be ignored.
|
||||
@@ -498,7 +501,11 @@ def download_weights_from_hf(
|
||||
# so we only have to call snapshot_download once.
|
||||
try:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
file_list = fs.ls(
|
||||
os.path.join(model_name_or_path, subfolder or ""),
|
||||
detail=False,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# If downloading safetensors and an index file exists, use the
|
||||
# specific file names from the index to avoid downloading
|
||||
@@ -510,6 +517,7 @@ def download_weights_from_hf(
|
||||
filename=SAFE_WEIGHTS_INDEX_NAME,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
with open(index_path) as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
@@ -570,6 +578,7 @@ def download_safetensors_index_file_from_hf(
|
||||
model_name_or_path: str,
|
||||
index_file: str,
|
||||
cache_dir: str | None,
|
||||
subfolder: str | None = None,
|
||||
revision: str | None = None,
|
||||
) -> None:
|
||||
"""Download hf safetensors index file from Hugging Face Hub.
|
||||
@@ -579,6 +588,8 @@ def download_safetensors_index_file_from_hf(
|
||||
index_file (str): The safetensors index file name
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
subfolder (Optional[str]): The subfolder within the model repository
|
||||
to download weights from.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
"""
|
||||
# Use file lock to prevent multiple processes from
|
||||
@@ -591,6 +602,7 @@ def download_safetensors_index_file_from_hf(
|
||||
filename=index_file,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
# If file not found on remote or locally, we should not fail since
|
||||
|
||||
@@ -3,15 +3,12 @@
|
||||
|
||||
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
from transformers import BatchFeature
|
||||
from transformers import WhisperConfig as HFWhisperConfig
|
||||
|
||||
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.model_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||
|
||||
|
||||
def _get_whisper_local_path(repo_id: str):
|
||||
if os.path.exists(repo_id):
|
||||
repo_local_path = repo_id
|
||||
else:
|
||||
repo_local_path = snapshot_download(repo_id, local_files_only=True)
|
||||
|
||||
return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after Whisper feature extraction.
|
||||
|
||||
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
# packed_modules_mapping for Q/K/V fusion during weight loading
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"kv_proj": ["k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
model_path = vllm_config.model_config.model
|
||||
|
||||
# Load WhisperConfig from the subfolder
|
||||
whisper_dir = _get_whisper_local_path(model_path)
|
||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
|
||||
|
||||
# Temporarily replace hf_config for WhisperEncoder.__init__()
|
||||
original_config = vllm_config.model_config.hf_config
|
||||
vllm_config.model_config.hf_config = whisper_config
|
||||
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
|
||||
whisper_config = HFWhisperConfig.from_pretrained(
|
||||
model_path,
|
||||
subfolder=KIMIA_WHISPER_SUBFOLDER,
|
||||
)
|
||||
|
||||
# Restore original config
|
||||
vllm_config.model_config.hf_config = original_config
|
||||
super().__init__(
|
||||
vllm_config=vllm_config.with_hf_config(whisper_config),
|
||||
prefix=prefix,
|
||||
init_in_fp32=init_in_fp32,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# audio tower
|
||||
"model.encoder.": "audio_tower.",
|
||||
# Audio projector (VQ-Adaptor)
|
||||
"model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
|
||||
"model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
|
||||
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
|
||||
"model.embed_tokens.": "language_model.model.embed_tokens.",
|
||||
"model.norm.": "language_model.model.norm.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
}
|
||||
},
|
||||
orig_to_new_substr={
|
||||
".fc1.": ".mlp.fc1.",
|
||||
".fc2.": ".mlp.fc2.",
|
||||
},
|
||||
)
|
||||
|
||||
# Audio placeholder token sequence
|
||||
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
|
||||
self.multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.model_path = vllm_config.model_config.model
|
||||
|
||||
self.secondary_weights = [
|
||||
DefaultModelLoader.Source(
|
||||
model_or_path=vllm_config.model_config.model,
|
||||
subfolder="whisper-large-v3",
|
||||
revision=None,
|
||||
)
|
||||
]
|
||||
|
||||
self.audio_tower = KimiAudioWhisperEncoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
|
||||
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
|
||||
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
|
||||
skipped_patterns = [
|
||||
# Audio tower
|
||||
"model.",
|
||||
# MIMO/TTS
|
||||
"mimo_layers.",
|
||||
"mimo_output.",
|
||||
"mimo_norm.",
|
||||
"audio_decoder.",
|
||||
]
|
||||
|
||||
# Filter weights
|
||||
filtered_weights = [
|
||||
(name, param)
|
||||
for name, param in weights
|
||||
if not any(pattern in name for pattern in skipped_patterns)
|
||||
]
|
||||
|
||||
# Separate main weights (non-Whisper) from Whisper weights
|
||||
main_weights = [
|
||||
(name, param)
|
||||
for name, param in filtered_weights
|
||||
if not name.startswith("audio_tower.")
|
||||
]
|
||||
|
||||
# Load main model weights (LLM + projector) with mapper
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
# Load Whisper encoder weights from subfolder
|
||||
whisper_dir = _get_whisper_local_path(self.model_path)
|
||||
whisper_path = os.path.join(whisper_dir, "model.safetensors")
|
||||
if os.path.exists(whisper_path):
|
||||
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
|
||||
loaded.update(whisper_loaded)
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skipped_patterns)
|
||||
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
return loaded
|
||||
|
||||
def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
|
||||
"""Load Whisper encoder weights from safetensors file with transformations."""
|
||||
if not os.path.exists(whisper_path):
|
||||
return set()
|
||||
|
||||
# Step 1: Load raw weights from safetensors file
|
||||
whisper_weights = []
|
||||
with safe_open(whisper_path, framework="pt") as f:
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
if key.startswith("model.encoder.") and "embed_positions" not in key:
|
||||
new_key = key.replace("model.encoder.", "")
|
||||
whisper_weights.append((new_key, f.get_tensor(key)))
|
||||
|
||||
# Step 2: Apply fc → mlp mapping using WeightsMapper
|
||||
fc_mapper = WeightsMapper(
|
||||
orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
|
||||
)
|
||||
whisper_mapped = list(fc_mapper.apply(whisper_weights))
|
||||
|
||||
# Step 3: Apply Q/K/V fusion manually
|
||||
stacked_params_mapping = [
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.audio_tower.named_parameters())
|
||||
whisper_loaded: set[str] = set()
|
||||
|
||||
for name, loaded_weight in whisper_mapped:
|
||||
fused = False
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
fused_name = name.replace(weight_name, param_name)
|
||||
if fused_name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[fused_name]
|
||||
param.weight_loader(param, loaded_weight, shard_id)
|
||||
whisper_loaded.add(f"audio_tower.{fused_name}")
|
||||
fused = True
|
||||
break
|
||||
|
||||
if not fused:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
whisper_loaded.add(f"audio_tower.{name}")
|
||||
|
||||
# Add embed_positions which is initialized randomly
|
||||
whisper_loaded.add("audio_tower.embed_positions.weight")
|
||||
|
||||
return whisper_loaded
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
|
||||
Reference in New Issue
Block a user