[Misc] Clean up Kimi-audio whisper encoder loading (#36903)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -52,6 +52,9 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
revision: str | None
|
revision: str | None
|
||||||
"""The optional model revision."""
|
"""The optional model revision."""
|
||||||
|
|
||||||
|
subfolder: str | None = None
|
||||||
|
"""The subfolder inside the model repo."""
|
||||||
|
|
||||||
prefix: str = ""
|
prefix: str = ""
|
||||||
"""A prefix to prepend to all weights."""
|
"""A prefix to prepend to all weights."""
|
||||||
|
|
||||||
@@ -81,6 +84,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
def _prepare_weights(
|
def _prepare_weights(
|
||||||
self,
|
self,
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
|
subfolder: str | None,
|
||||||
revision: str | None,
|
revision: str | None,
|
||||||
fall_back_to_pt: bool,
|
fall_back_to_pt: bool,
|
||||||
allow_patterns_overrides: list[str] | None,
|
allow_patterns_overrides: list[str] | None,
|
||||||
@@ -143,11 +147,15 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
self.load_config.download_dir,
|
self.load_config.download_dir,
|
||||||
allow_patterns,
|
allow_patterns,
|
||||||
revision,
|
revision,
|
||||||
|
subfolder=subfolder,
|
||||||
ignore_patterns=self.load_config.ignore_patterns,
|
ignore_patterns=self.load_config.ignore_patterns,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hf_folder = model_name_or_path
|
hf_folder = model_name_or_path
|
||||||
|
|
||||||
|
if subfolder is not None:
|
||||||
|
hf_folder = os.path.join(hf_folder, subfolder)
|
||||||
|
|
||||||
hf_weights_files: list[str] = []
|
hf_weights_files: list[str] = []
|
||||||
for pattern in allow_patterns:
|
for pattern in allow_patterns:
|
||||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||||
@@ -166,8 +174,9 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
download_safetensors_index_file_from_hf(
|
download_safetensors_index_file_from_hf(
|
||||||
model_name_or_path,
|
model_name_or_path,
|
||||||
index_file,
|
index_file,
|
||||||
self.load_config.download_dir,
|
cache_dir=self.load_config.download_dir,
|
||||||
revision,
|
subfolder=subfolder,
|
||||||
|
revision=revision,
|
||||||
)
|
)
|
||||||
hf_weights_files = filter_duplicate_safetensors_files(
|
hf_weights_files = filter_duplicate_safetensors_files(
|
||||||
hf_weights_files, hf_folder, index_file
|
hf_weights_files, hf_folder, index_file
|
||||||
@@ -189,6 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
extra_config = self.load_config.model_loader_extra_config
|
extra_config = self.load_config.model_loader_extra_config
|
||||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||||
source.model_or_path,
|
source.model_or_path,
|
||||||
|
source.subfolder,
|
||||||
source.revision,
|
source.revision,
|
||||||
source.fall_back_to_pt,
|
source.fall_back_to_pt,
|
||||||
source.allow_patterns_overrides,
|
source.allow_patterns_overrides,
|
||||||
@@ -269,8 +279,9 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
|
|
||||||
def download_model(self, model_config: ModelConfig) -> None:
|
def download_model(self, model_config: ModelConfig) -> None:
|
||||||
self._prepare_weights(
|
self._prepare_weights(
|
||||||
model_config.model,
|
model_name_or_path=model_config.model,
|
||||||
model_config.revision,
|
subfolder=None,
|
||||||
|
revision=model_config.revision,
|
||||||
fall_back_to_pt=True,
|
fall_back_to_pt=True,
|
||||||
allow_patterns_overrides=None,
|
allow_patterns_overrides=None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -472,6 +472,7 @@ def download_weights_from_hf(
|
|||||||
cache_dir: str | None,
|
cache_dir: str | None,
|
||||||
allow_patterns: list[str],
|
allow_patterns: list[str],
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
|
subfolder: str | None = None,
|
||||||
ignore_patterns: str | list[str] | None = None,
|
ignore_patterns: str | list[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Download model weights from Hugging Face Hub.
|
"""Download model weights from Hugging Face Hub.
|
||||||
@@ -484,6 +485,8 @@ def download_weights_from_hf(
|
|||||||
weight files. Files matched by any of the patterns will be
|
weight files. Files matched by any of the patterns will be
|
||||||
downloaded.
|
downloaded.
|
||||||
revision (Optional[str]): The revision of the model.
|
revision (Optional[str]): The revision of the model.
|
||||||
|
subfolder (Optional[str]): The subfolder within the model repository
|
||||||
|
to download weights from.
|
||||||
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
||||||
filter out the weight files. Files matched by any of the patterns
|
filter out the weight files. Files matched by any of the patterns
|
||||||
will be ignored.
|
will be ignored.
|
||||||
@@ -498,7 +501,11 @@ def download_weights_from_hf(
|
|||||||
# so we only have to call snapshot_download once.
|
# so we only have to call snapshot_download once.
|
||||||
try:
|
try:
|
||||||
fs = HfFileSystem()
|
fs = HfFileSystem()
|
||||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
file_list = fs.ls(
|
||||||
|
os.path.join(model_name_or_path, subfolder or ""),
|
||||||
|
detail=False,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
|
||||||
# If downloading safetensors and an index file exists, use the
|
# If downloading safetensors and an index file exists, use the
|
||||||
# specific file names from the index to avoid downloading
|
# specific file names from the index to avoid downloading
|
||||||
@@ -510,6 +517,7 @@ def download_weights_from_hf(
|
|||||||
filename=SAFE_WEIGHTS_INDEX_NAME,
|
filename=SAFE_WEIGHTS_INDEX_NAME,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
)
|
)
|
||||||
with open(index_path) as f:
|
with open(index_path) as f:
|
||||||
weight_map = json.load(f)["weight_map"]
|
weight_map = json.load(f)["weight_map"]
|
||||||
@@ -570,6 +578,7 @@ def download_safetensors_index_file_from_hf(
|
|||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
index_file: str,
|
index_file: str,
|
||||||
cache_dir: str | None,
|
cache_dir: str | None,
|
||||||
|
subfolder: str | None = None,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Download hf safetensors index file from Hugging Face Hub.
|
"""Download hf safetensors index file from Hugging Face Hub.
|
||||||
@@ -579,6 +588,8 @@ def download_safetensors_index_file_from_hf(
|
|||||||
index_file (str): The safetensors index file name
|
index_file (str): The safetensors index file name
|
||||||
cache_dir (Optional[str]): The cache directory to store the model
|
cache_dir (Optional[str]): The cache directory to store the model
|
||||||
weights. If None, will use HF defaults.
|
weights. If None, will use HF defaults.
|
||||||
|
subfolder (Optional[str]): The subfolder within the model repository
|
||||||
|
to download weights from.
|
||||||
revision (Optional[str]): The revision of the model.
|
revision (Optional[str]): The revision of the model.
|
||||||
"""
|
"""
|
||||||
# Use file lock to prevent multiple processes from
|
# Use file lock to prevent multiple processes from
|
||||||
@@ -591,6 +602,7 @@ def download_safetensors_index_file_from_hf(
|
|||||||
filename=index_file,
|
filename=index_file,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
)
|
)
|
||||||
# If file not found on remote or locally, we should not fail since
|
# If file not found on remote or locally, we should not fail since
|
||||||
|
|||||||
@@ -3,15 +3,12 @@
|
|||||||
|
|
||||||
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
|
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
import os
|
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import Any, ClassVar, Literal
|
from typing import Any, ClassVar, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from safetensors import safe_open
|
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
from transformers import WhisperConfig as HFWhisperConfig
|
from transformers import WhisperConfig as HFWhisperConfig
|
||||||
|
|
||||||
@@ -19,9 +16,8 @@ from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
|||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.inputs.data import PromptType, TokensPrompt
|
from vllm.inputs.data import PromptType, TokensPrompt
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader import DefaultModelLoader
|
||||||
default_weight_loader,
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
)
|
|
||||||
from vllm.model_executor.models.interfaces import (
|
from vllm.model_executor.models.interfaces import (
|
||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
@@ -64,15 +60,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
|||||||
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||||
|
|
||||||
|
|
||||||
def _get_whisper_local_path(repo_id: str):
    """Return the path of the Whisper subfolder belonging to ``repo_id``.

    ``repo_id`` is used as-is when it names an existing filesystem path.
    Otherwise it is resolved through ``snapshot_download`` with
    ``local_files_only=True``, i.e. no download is requested here.

    Args:
        repo_id: A local directory or a repository identifier.

    Returns:
        The resolved root joined with ``KIMIA_WHISPER_SUBFOLDER``.
    """
    repo_local_path = (
        repo_id
        if os.path.exists(repo_id)
        else snapshot_download(repo_id, local_files_only=True)
    )
    return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||||
"""Compute output lengths after Whisper feature extraction.
|
"""Compute output lengths after Whisper feature extraction.
|
||||||
|
|
||||||
@@ -93,7 +80,6 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
|||||||
# packed_modules_mapping for Q/K/V fusion during weight loading
|
# packed_modules_mapping for Q/K/V fusion during weight loading
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||||
"kv_proj": ["k_proj", "v_proj"],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -104,19 +90,49 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
|||||||
model_path = vllm_config.model_config.model
|
model_path = vllm_config.model_config.model
|
||||||
|
|
||||||
# Load WhisperConfig from the subfolder
|
# Load WhisperConfig from the subfolder
|
||||||
whisper_dir = _get_whisper_local_path(model_path)
|
whisper_config = HFWhisperConfig.from_pretrained(
|
||||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
|
model_path,
|
||||||
|
subfolder=KIMIA_WHISPER_SUBFOLDER,
|
||||||
# Temporarily replace hf_config for WhisperEncoder.__init__()
|
|
||||||
original_config = vllm_config.model_config.hf_config
|
|
||||||
vllm_config.model_config.hf_config = whisper_config
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Restore original config
|
super().__init__(
|
||||||
vllm_config.model_config.hf_config = original_config
|
vllm_config=vllm_config.with_hf_config(whisper_config),
|
||||||
|
prefix=prefix,
|
||||||
|
init_in_fp32=init_in_fp32,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load encoder weights, routing q/k/v projections into ``qkv_proj``.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs to load.

    Returns:
        The set of parameter names that were loaded, using the fused
        ``qkv_proj`` name for q/k/v shards.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    qkv_shards = (
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    )
    known_params = dict(self.named_parameters())
    loaded_params: set[str] = set()

    for name, tensor in weights:
        for fused_key, shard_key, shard_id in qkv_shards:
            if shard_key not in name:
                continue
            # The rename is kept in ``name`` so that the fused name is what
            # gets recorded (and re-checked) further down.
            name = name.replace(shard_key, fused_key)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in known_params:
                continue

            target = known_params[name]
            target.weight_loader(target, tensor, shard_id)
            break
        else:
            # No q/k/v shard was loaded for this entry: treat it as a
            # regular (non-stacked) parameter.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in known_params:
                continue

            target = known_params[name]
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            load_fn(target, tensor)
        loaded_params.add(name)
    return loaded_params
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
@@ -374,6 +390,8 @@ class KimiAudioForConditionalGeneration(
|
|||||||
|
|
||||||
hf_to_vllm_mapper = WeightsMapper(
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
orig_to_new_prefix={
|
orig_to_new_prefix={
|
||||||
|
# audio tower
|
||||||
|
"model.encoder.": "audio_tower.",
|
||||||
# Audio projector (VQ-Adaptor)
|
# Audio projector (VQ-Adaptor)
|
||||||
"model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
|
"model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
|
||||||
"model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
|
"model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
|
||||||
@@ -384,7 +402,11 @@ class KimiAudioForConditionalGeneration(
|
|||||||
"model.embed_tokens.": "language_model.model.embed_tokens.",
|
"model.embed_tokens.": "language_model.model.embed_tokens.",
|
||||||
"model.norm.": "language_model.model.norm.",
|
"model.norm.": "language_model.model.norm.",
|
||||||
"lm_head.": "language_model.lm_head.",
|
"lm_head.": "language_model.lm_head.",
|
||||||
}
|
},
|
||||||
|
orig_to_new_substr={
|
||||||
|
".fc1.": ".mlp.fc1.",
|
||||||
|
".fc2.": ".mlp.fc2.",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Audio placeholder token sequence
|
# Audio placeholder token sequence
|
||||||
@@ -401,6 +423,14 @@ class KimiAudioForConditionalGeneration(
|
|||||||
self.multimodal_config = vllm_config.model_config.multimodal_config
|
self.multimodal_config = vllm_config.model_config.multimodal_config
|
||||||
self.model_path = vllm_config.model_config.model
|
self.model_path = vllm_config.model_config.model
|
||||||
|
|
||||||
|
self.secondary_weights = [
|
||||||
|
DefaultModelLoader.Source(
|
||||||
|
model_or_path=vllm_config.model_config.model,
|
||||||
|
subfolder="whisper-large-v3",
|
||||||
|
revision=None,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
self.audio_tower = KimiAudioWhisperEncoder(
|
self.audio_tower = KimiAudioWhisperEncoder(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||||
@@ -577,99 +607,19 @@ class KimiAudioForConditionalGeneration(
|
|||||||
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
|
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
|
||||||
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
|
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
|
||||||
skipped_patterns = [
|
skipped_patterns = [
|
||||||
|
# Audio tower
|
||||||
|
"model.",
|
||||||
|
# MIMO/TTS
|
||||||
"mimo_layers.",
|
"mimo_layers.",
|
||||||
"mimo_output.",
|
"mimo_output.",
|
||||||
"mimo_norm.",
|
"mimo_norm.",
|
||||||
"audio_decoder.",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Filter weights
|
|
||||||
filtered_weights = [
|
|
||||||
(name, param)
|
|
||||||
for name, param in weights
|
|
||||||
if not any(pattern in name for pattern in skipped_patterns)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Separate main weights (non-Whisper) from Whisper weights
|
|
||||||
main_weights = [
|
|
||||||
(name, param)
|
|
||||||
for name, param in filtered_weights
|
|
||||||
if not name.startswith("audio_tower.")
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Load main model weights (LLM + projector) with mapper
|
# Load main model weights (LLM + projector) with mapper
|
||||||
loader = AutoWeightsLoader(self)
|
loader = AutoWeightsLoader(self, skip_prefixes=skipped_patterns)
|
||||||
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
|
loaded = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
# Load Whisper encoder weights from subfolder
|
|
||||||
whisper_dir = _get_whisper_local_path(self.model_path)
|
|
||||||
whisper_path = os.path.join(whisper_dir, "model.safetensors")
|
|
||||||
if os.path.exists(whisper_path):
|
|
||||||
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
|
|
||||||
loaded.update(whisper_loaded)
|
|
||||||
|
|
||||||
return loaded
|
return loaded
|
||||||
|
|
||||||
def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
    """Load Whisper encoder weights from a safetensors file into the audio tower.

    Only keys starting with ``model.encoder.`` are read, and keys containing
    ``embed_positions`` are skipped. ``.fc1.`` / ``.fc2.`` are renamed to
    ``.mlp.fc1.`` / ``.mlp.fc2.`` and q/k/v projections are loaded as shards
    of the fused ``qkv_proj`` parameter.

    Args:
        whisper_path: Path to the safetensors file.

    Returns:
        Names of the loaded parameters, each prefixed with ``audio_tower.``.
        An empty set is returned when ``whisper_path`` does not exist.
    """
    if not os.path.exists(whisper_path):
        return set()

    # Step 1: collect the raw encoder tensors, stripping the HF prefix.
    encoder_prefix = "model.encoder."
    raw_weights = []
    with safe_open(whisper_path, framework="pt") as f:
        for key in f.keys():  # noqa: SIM118
            if not key.startswith(encoder_prefix) or "embed_positions" in key:
                continue
            raw_weights.append((key.replace(encoder_prefix, ""), f.get_tensor(key)))

    # Step 2: rename the feed-forward layers (fc -> mlp.fc).
    fc_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )
    renamed_weights = list(fc_mapper.apply(raw_weights))

    # Step 3: load, fusing q/k/v into qkv_proj.
    # (fused parameter name, checkpoint shard name, shard id)
    qkv_shards = (
        (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
        (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
        (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    )
    tower_params = dict(self.audio_tower.named_parameters())
    whisper_loaded: set[str] = set()

    for name, tensor in renamed_weights:
        for fused_key, shard_key, shard_id in qkv_shards:
            if shard_key not in name:
                continue
            fused_name = name.replace(shard_key, fused_key)
            if fused_name not in tower_params:
                continue

            target = tower_params[fused_name]
            target.weight_loader(target, tensor, shard_id)
            whisper_loaded.add(f"audio_tower.{fused_name}")
            break
        else:
            # Not loaded as a q/k/v shard: load directly, silently skipping
            # names (biases included) the audio tower does not define.
            if name in tower_params:
                target = tower_params[name]
                load_fn = getattr(target, "weight_loader", default_weight_loader)
                load_fn(target, tensor)
                whisper_loaded.add(f"audio_tower.{name}")

    # embed_positions is never read from the file (skipped in step 1) but is
    # still reported as loaded.
    whisper_loaded.add("audio_tower.embed_positions.weight")

    return whisper_loaded
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_speech_to_text_config(
|
def get_speech_to_text_config(
|
||||||
cls, model_config: ModelConfig, task_type: str
|
cls, model_config: ModelConfig, task_type: str
|
||||||
|
|||||||
Reference in New Issue
Block a user