[Model] Add support for moonshotai/Kimi-Audio-7B-Instruct (#36127)
Signed-off-by: tunglinwood <tunglinwood@gmail.com> Signed-off-by: tunglinwood <tomwu.tunglin@gmail.com> Signed-off-by: tunglinwood <113751333+tunglinwood@users.noreply.github.com>
This commit is contained in:
@@ -713,8 +713,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `KananaVForConditionalGeneration` | Kanana-V | T + I<sup>+</sup> | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ |
|
||||
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
|
||||
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
|
||||
| `KimiAudioForConditionalGeneration` | Kimi-Audio | T + A<sup>+</sup> | `moonshotai/Kimi-Audio-7B-Instruct` | | ✅︎ |
|
||||
| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I<sup>+</sup> | `moonshotai/Kimi-K2.5` | | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
|
||||
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
|
||||
| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I<sup>+</sup> | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||
|
||||
@@ -201,6 +201,34 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Kimi-Audio-7B-Instruct
|
||||
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
|
||||
"""Kimi-Audio-7B-Instruct for audio transcription and understanding."""
|
||||
model_name = "moonshotai/Kimi-Audio-7B-Instruct"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
)
|
||||
|
||||
# Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features
|
||||
audio_placeholder = "<|im_kimia_text_blank|>" * audio_count
|
||||
# Default prompt for transcription
|
||||
if not question:
|
||||
question = "Please transcribe the audio"
|
||||
prompt = f"{audio_placeholder}{question}"
|
||||
|
||||
# Stop at EOS token (151644) to prevent repetition
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
stop_token_ids=[151644],
|
||||
)
|
||||
|
||||
|
||||
# MiDashengLM
|
||||
def run_midashenglm(question: str, audio_count: int):
|
||||
model_name = "mispeech/midashenglm-7b"
|
||||
@@ -485,6 +513,7 @@ model_example_map = {
|
||||
"glmasr": run_glmasr,
|
||||
"funaudiochat": run_funaudiochat,
|
||||
"granite_speech": run_granite_speech,
|
||||
"kimi_audio": run_kimi_audio,
|
||||
"midashenglm": run_midashenglm,
|
||||
"minicpmo": run_minicpmo,
|
||||
"phi4_mm": run_phi4mm,
|
||||
|
||||
@@ -198,13 +198,17 @@ def get_text_token_prompts(
|
||||
mm_counts,
|
||||
mm_options={},
|
||||
)
|
||||
assert isinstance(inputs.prompt, str)
|
||||
|
||||
text_prompt = inputs.prompt
|
||||
token_prompt = tokenizer.encode(
|
||||
text_prompt,
|
||||
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True),
|
||||
)
|
||||
# Some models (e.g., Kimi-Audio) return token IDs directly instead of str
|
||||
if isinstance(inputs.prompt, list):
|
||||
text_prompt = None
|
||||
token_prompt = inputs.prompt
|
||||
else:
|
||||
assert isinstance(inputs.prompt, str)
|
||||
text_prompt = inputs.prompt
|
||||
token_prompt = tokenizer.encode(
|
||||
text_prompt,
|
||||
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True),
|
||||
)
|
||||
|
||||
return text_prompt, token_prompt
|
||||
|
||||
|
||||
@@ -857,6 +857,15 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"Kwai-Keye/Keye-VL-1_5-8B",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"MoonshotKimiaForCausalLM": _HfExamplesInfo(
|
||||
"moonshotai/Kimi-Audio-7B-Instruct",
|
||||
tokenizer_mode="kimi_audio",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
|
||||
"moonshotai/Kimi-K2.5",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"KimiVLForConditionalGeneration": _HfExamplesInfo(
|
||||
"moonshotai/Kimi-VL-A3B-Instruct",
|
||||
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
|
||||
@@ -870,10 +879,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
)
|
||||
},
|
||||
),
|
||||
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
|
||||
"moonshotai/Kimi-K2.5",
|
||||
trust_remote_code=True,
|
||||
),
|
||||
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
|
||||
"lightonai/LightOnOCR-1B-1025"
|
||||
),
|
||||
|
||||
@@ -103,6 +103,12 @@ def can_initialize(
|
||||
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
|
||||
)
|
||||
|
||||
if model_arch == "MoonshotKimiaForCausalLM":
|
||||
pytest.skip(
|
||||
"Kimi-Audio requires SpeechToTextConfig "
|
||||
"which is not configured in test environment"
|
||||
)
|
||||
|
||||
if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
725
vllm/model_executor/models/kimi_audio.py
Normal file
725
vllm/model_executor/models/kimi_audio.py
Normal file
@@ -0,0 +1,725 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
|
||||
|
||||
import os
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors import safe_open
|
||||
from transformers import BatchFeature
|
||||
from transformers import WhisperConfig as HFWhisperConfig
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsTranscription,
|
||||
)
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from vllm.model_executor.models.whisper import WhisperEncoder
|
||||
from vllm.model_executor.models.whisper_utils import ISO639_1_SUPPORTED_LANGS
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig
|
||||
from vllm.multimodal.parse import (
|
||||
AudioItem,
|
||||
DictEmbeddingItems,
|
||||
ModalityData,
|
||||
ModalityDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
)
|
||||
from vllm.multimodal.processing.processor import BaseMultiModalProcessor
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
from vllm.transformers_utils.processor import cached_feature_extractor_from_config
|
||||
from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
# Kimi-Audio constants
|
||||
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after Whisper feature extraction.
|
||||
|
||||
Whisper processes audio through multiple conv layers with stride=2,
|
||||
producing 13 output features per 100 input samples.
|
||||
"""
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
|
||||
|
||||
class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
"""WhisperEncoder for Kimi-Audio with packed_modules_mapping."""
|
||||
|
||||
# packed_modules_mapping for Q/K/V fusion during weight loading
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"kv_proj": ["k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
|
||||
):
|
||||
# Load Whisper config from subfolder (authoritative source)
|
||||
# Kimi-Audio stores Whisper config in whisper-large-v3/config.json
|
||||
model_path = vllm_config.model_config.model
|
||||
whisper_config_path = os.path.join(model_path, KIMIA_WHISPER_SUBFOLDER)
|
||||
|
||||
# Load WhisperConfig from the subfolder
|
||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path)
|
||||
|
||||
# Temporarily replace hf_config for WhisperEncoder.__init__()
|
||||
original_config = vllm_config.model_config.hf_config
|
||||
vllm_config.model_config.hf_config = whisper_config
|
||||
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
|
||||
)
|
||||
|
||||
# Restore original config
|
||||
vllm_config.model_config.hf_config = original_config
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Processing Info, Dummy Inputs, and MultiModal Processor
|
||||
# (Following Qwen3ASR pattern - same file as model)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class KimiAudioProcessingInfo(BaseProcessingInfo):
|
||||
"""Processing info for vLLM registry."""
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.model_config.hf_config
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
|
||||
"""Get KimiAudioProcessor with feature extractor and tokenizer."""
|
||||
# Use vLLM's cached loader for feature extractor
|
||||
feature_extractor = cached_feature_extractor_from_config(
|
||||
self.ctx.model_config,
|
||||
subfolder=KIMIA_WHISPER_SUBFOLDER,
|
||||
)
|
||||
|
||||
# Use vLLM's standard tokenizer loading (respects tokenizer_mode)
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# Construct processor directly
|
||||
return KimiAudioProcessor(
|
||||
feature_extractor=feature_extractor,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
def get_feature_extractor(self, **kwargs: object):
|
||||
"""Get feature extractor using vLLM's cached loader."""
|
||||
return cached_feature_extractor_from_config(
|
||||
self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": 1}
|
||||
|
||||
def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
|
||||
"""Get data parser for audio inputs."""
|
||||
return KimiAudioMultiModalDataParser(
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
|
||||
|
||||
class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
|
||||
"""Dummy inputs builder for vLLM registry."""
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
|
||||
"""Return dummy text as token IDs directly."""
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
if num_audios == 0:
|
||||
return [198] # "Transcribe" tokenized
|
||||
# Return as token IDs directly to avoid tokenizer issues
|
||||
return [
|
||||
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
|
||||
KimiAudioProcessor.KIMIA_TEXT_BLANK,
|
||||
KimiAudioProcessor.KIMIA_MEDIA_END,
|
||||
] * num_audios
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
if num_audios == 0:
|
||||
return {}
|
||||
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
target_audio_length = (
|
||||
min(feature_extractor.chunk_length, 30) * feature_extractor.sampling_rate
|
||||
)
|
||||
|
||||
return {
|
||||
"audio": self._get_dummy_audios(
|
||||
length=target_audio_length, num_audios=num_audios
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# Field config for Kimi-Audio multimodal data
|
||||
_KIMIAUDIO_FIELD_CONFIG = {
|
||||
"whisper_input_features": MultiModalFieldConfig.batched("audio"),
|
||||
"feature_attention_mask": MultiModalFieldConfig.batched("audio"),
|
||||
}
|
||||
|
||||
|
||||
class KimiAudioMultiModalDataParser(MultiModalDataParser):
|
||||
"""Custom data parser for Kimi-Audio multimodal data."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Whisper expects 16kHz audio
|
||||
super().__init__(target_sr=16000, **kwargs)
|
||||
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
|
||||
) -> ModalityDataItems[Any, Any] | None:
|
||||
if isinstance(data, dict):
|
||||
return DictEmbeddingItems(
|
||||
data,
|
||||
modality="audio",
|
||||
required_fields={"whisper_input_features", "feature_attention_mask"},
|
||||
fields_factory=lambda hf_inputs: _KIMIAUDIO_FIELD_CONFIG,
|
||||
)
|
||||
|
||||
return super()._parse_audio_data(data)
|
||||
|
||||
|
||||
class KimiAudioMultiModalProcessor(BaseMultiModalProcessor[KimiAudioProcessingInfo]):
|
||||
"""vLLM multi-modal processor wrapper for Kimi-Audio."""
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
"""Call the HuggingFace processor."""
|
||||
# Convert mm_data format: {'audios': [...]} -> {'audio': ...}
|
||||
mm_data = dict(mm_data)
|
||||
audios = mm_data.pop("audios", [])
|
||||
|
||||
# Convert audio format: [(array, sr), ...] -> [array, ...]
|
||||
# KimiAudioProcessor expects raw numpy arrays
|
||||
if audios:
|
||||
audio_arrays = []
|
||||
for aud in audios:
|
||||
if isinstance(aud, (tuple, list)) and len(aud) == 2:
|
||||
# Format: (audio_array, sampling_rate)
|
||||
audio_arrays.append(aud[0])
|
||||
elif isinstance(aud, np.ndarray):
|
||||
audio_arrays.append(aud)
|
||||
else:
|
||||
audio_arrays.append(aud)
|
||||
mm_data["audio"] = audio_arrays
|
||||
|
||||
# Use the context's call_hf_processor for proper handling
|
||||
return self.info.ctx.call_hf_processor(
|
||||
self.info.get_hf_processor(**mm_kwargs),
|
||||
dict(text=prompt, **mm_data),
|
||||
dict(**mm_kwargs, **tok_kwargs),
|
||||
)
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, Any]:
|
||||
"""Get multi-modal field configuration."""
|
||||
return _KIMIAUDIO_FIELD_CONFIG
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs,
|
||||
out_mm_kwargs,
|
||||
) -> Sequence[PromptReplacement]:
|
||||
"""Get prompt updates for audio tokens."""
|
||||
# Get audio feature lengths from processed output
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
feature_attention_mask = out_mm_data.get("feature_attention_mask")
|
||||
|
||||
if feature_attention_mask is not None:
|
||||
audio_output_lens = _get_feat_extract_output_lengths(
|
||||
feature_attention_mask.sum(-1)
|
||||
)
|
||||
audio_output_lengths = audio_output_lens.tolist()
|
||||
else:
|
||||
audio_output_lengths = []
|
||||
|
||||
def get_replacement_kimiaudio(item_idx: int):
|
||||
num_features = (
|
||||
audio_output_lengths[item_idx]
|
||||
if item_idx < len(audio_output_lengths)
|
||||
else 376
|
||||
)
|
||||
if num_features == 0:
|
||||
num_features = 376 # Default Kimi-Audio sequence length
|
||||
# Return the placeholder token ID repeated num_features times
|
||||
return [KimiAudioProcessor.KIMIA_TEXT_BLANK] * num_features
|
||||
|
||||
# Use the token ID as target (as a list)
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=[KimiAudioProcessor.KIMIA_TEXT_BLANK],
|
||||
replacement=get_replacement_kimiaudio,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Model Definition
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class KimiAudioMultiModalProjector(nn.Module):
|
||||
"""Projects Whisper features to LLM embedding space.
|
||||
|
||||
Kimi-Audio VQ-Adaptor architecture:
|
||||
Custom Whisper (5120) → Linear[5120→3584] → Linear[3584→3584] → LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
whisper_dim: int = 5120, # Kimi-Audio custom Whisper encoder dim
|
||||
llm_dim: int = 3584,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.whisper_dim = whisper_dim
|
||||
self.llm_dim = llm_dim
|
||||
|
||||
# VQ-Adaptor layers (exact checkpoint structure)
|
||||
# layers.0: Linear[5120 → 3584]
|
||||
self.vq_adaptor_layers_0 = nn.Linear(whisper_dim, llm_dim)
|
||||
# layers.3: Linear[3584 → 3584]
|
||||
self.vq_adaptor_layers_3 = nn.Linear(llm_dim, llm_dim)
|
||||
# layers.4: LayerNorm[3584]
|
||||
self.vq_adaptor_layers_4 = nn.LayerNorm(llm_dim)
|
||||
|
||||
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
|
||||
# Project: [B, T, 5120] → [B, T, 3584]
|
||||
hidden = self.vq_adaptor_layers_0(audio_features)
|
||||
hidden = torch.nn.functional.gelu(hidden)
|
||||
hidden = self.vq_adaptor_layers_3(hidden)
|
||||
hidden = self.vq_adaptor_layers_4(hidden)
|
||||
return hidden
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
KimiAudioMultiModalProcessor,
|
||||
info=KimiAudioProcessingInfo,
|
||||
dummy_inputs=KimiAudioDummyInputsBuilder,
|
||||
)
|
||||
class KimiAudioForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsTranscription,
|
||||
):
|
||||
"""Kimi-Audio model for ASR transcription."""
|
||||
|
||||
# Kimi-Audio supports a subset of Whisper's supported languages
|
||||
supported_languages: ClassVar[Mapping[str, str]] = {
|
||||
k: ISO639_1_SUPPORTED_LANGS[k]
|
||||
for k in ["zh", "en", "ja", "ko", "de", "fr", "es", "it", "pt", "ru", "ar"]
|
||||
}
|
||||
supports_transcription: ClassVar[Literal[True]] = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# Audio projector (VQ-Adaptor)
|
||||
"model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
|
||||
"model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
|
||||
"model.vq_adaptor.layers.4.": "multi_modal_projector.vq_adaptor_layers_4.",
|
||||
# Language model
|
||||
"model.layers.": "language_model.model.layers.",
|
||||
# Embeddings and output
|
||||
"model.embed_tokens.": "language_model.model.embed_tokens.",
|
||||
"model.norm.": "language_model.model.norm.",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
}
|
||||
)
|
||||
|
||||
# Audio placeholder token sequence
|
||||
AUDIO_PLACEHOLDER = "<|im_media_begin|><|im_kimia_text_blank|><|im_media_end|>"
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
return cls.AUDIO_PLACEHOLDER if modality.startswith("audio") else None
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.model_path = vllm_config.model_config.model
|
||||
|
||||
self.audio_tower = KimiAudioWhisperEncoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
)
|
||||
|
||||
self.multi_modal_projector = KimiAudioMultiModalProjector(
|
||||
whisper_dim=getattr(self.config, "kimia_adaptor_input_dim", 5120),
|
||||
llm_dim=self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(
|
||||
self.config, architectures=["Qwen2ForCausalLM"]
|
||||
),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.config.vocab_size,
|
||||
self.config.vocab_size,
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object
|
||||
) -> dict[str, torch.Tensor] | None:
|
||||
whisper_input_features = kwargs.pop("whisper_input_features", None)
|
||||
if whisper_input_features is None:
|
||||
return None
|
||||
|
||||
return {"whisper_input_features": whisper_input_features}
|
||||
|
||||
def _process_audio_input(
|
||||
self, audio_input: dict[str, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
input_features = audio_input["whisper_input_features"]
|
||||
|
||||
# KimiAudioWhisperEncoder expects list of tensors
|
||||
if input_features.dim() == 3:
|
||||
input_features = input_features.unbind(dim=0)
|
||||
|
||||
# Run through Whisper encoder
|
||||
audio_features = self.audio_tower(input_features)
|
||||
|
||||
# Reshape for 4x downsampling (Whisper outputs at 50Hz, need 12.5Hz)
|
||||
B, T, D = audio_features.shape
|
||||
if T % 4 != 0:
|
||||
pad_len = 4 - (T % 4)
|
||||
audio_features = torch.nn.functional.pad(audio_features, (0, 0, 0, pad_len))
|
||||
T = audio_features.shape[1] # Update T after padding
|
||||
|
||||
audio_features = audio_features.reshape(B, T // 4, D * 4)
|
||||
|
||||
# Project to LLM dimension
|
||||
audio_embeds = self.multi_modal_projector(audio_features)
|
||||
return audio_embeds
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> list[torch.Tensor] | None:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
if audio_input is None:
|
||||
return []
|
||||
|
||||
audio_embeds = self._process_audio_input(audio_input)
|
||||
|
||||
# audio_embeds shape: [batch_size, seq_len, hidden_dim]
|
||||
# Return as list of 2D tensors, one per batch item
|
||||
if audio_embeds.dim() == 3:
|
||||
# Unbind batch dimension: [B, T, D] -> list of B tensors [T, D]
|
||||
return list(audio_embeds.unbind(dim=0))
|
||||
else:
|
||||
# Single sample: [T, D] -> wrap in list
|
||||
return [audio_embeds]
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] | None = None,
|
||||
*,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Embed input IDs and fuse with audio embeddings.
|
||||
|
||||
Kimi-Audio fusion: inputs_embeds = (text_emb + audio_emb) × √2
|
||||
|
||||
For PP compatibility, we use the is_multimodal mask from vLLM engine
|
||||
which is correctly computed per pipeline stage.
|
||||
"""
|
||||
# Get text embeddings
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
|
||||
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||
return inputs_embeds
|
||||
|
||||
# is_multimodal must be provided for PP to work correctly
|
||||
if is_multimodal is None or not is_multimodal.any():
|
||||
return inputs_embeds
|
||||
|
||||
# multimodal_embeddings[0] contains audio embeddings
|
||||
audio_embeds = multimodal_embeddings[0]
|
||||
|
||||
# Handle different tensor structures
|
||||
if isinstance(audio_embeds, (list, tuple)):
|
||||
audio_embeds = torch.cat(audio_embeds, dim=0)
|
||||
elif audio_embeds.dim() == 3:
|
||||
audio_embeds = audio_embeds.reshape(-1, audio_embeds.shape[-1])
|
||||
|
||||
# In PP, audio_embeds count should match is_multimodal.sum()
|
||||
# For now, use embeddings sequentially
|
||||
# (works for non-PP, PP needs vLLM infra fix)
|
||||
num_mm_tokens = is_multimodal.sum().item()
|
||||
num_audio_embeds = audio_embeds.shape[0]
|
||||
|
||||
# Use the minimum of available embeddings and positions
|
||||
# This ensures we don't access out-of-bounds
|
||||
num_to_use = min(num_audio_embeds, num_mm_tokens)
|
||||
|
||||
# Get positions for the tokens we'll actually process
|
||||
mm_positions = is_multimodal.nonzero(as_tuple=True)[0]
|
||||
actual_mm_mask = torch.zeros_like(is_multimodal)
|
||||
actual_mm_mask[mm_positions[:num_to_use]] = True
|
||||
|
||||
# Use corresponding embeddings
|
||||
used_audio_embeds = audio_embeds[:num_to_use]
|
||||
|
||||
# Save text embeddings at multimodal positions
|
||||
text_at_mm_positions = inputs_embeds[actual_mm_mask].clone()
|
||||
|
||||
# Replace text with audio at multimodal positions
|
||||
inputs_embeds[actual_mm_mask] = used_audio_embeds.to(dtype=inputs_embeds.dtype)
|
||||
|
||||
# Apply Kimi-Audio's unique fusion formula: (text + audio) × √2
|
||||
inputs_embeds[actual_mm_mask] = (
|
||||
inputs_embeds[actual_mm_mask] + text_at_mm_positions
|
||||
) * (2**0.5)
|
||||
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model.model(
|
||||
input_ids,
|
||||
positions,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(
|
||||
self.language_model.lm_head, hidden_states, sampling_metadata
|
||||
)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
|
||||
# Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
|
||||
skipped_patterns = [
|
||||
"mimo_layers.",
|
||||
"mimo_output.",
|
||||
"mimo_norm.",
|
||||
"audio_decoder.",
|
||||
]
|
||||
|
||||
# Filter weights
|
||||
filtered_weights = [
|
||||
(name, param)
|
||||
for name, param in weights
|
||||
if not any(pattern in name for pattern in skipped_patterns)
|
||||
]
|
||||
|
||||
# Separate main weights (non-Whisper) from Whisper weights
|
||||
main_weights = [
|
||||
(name, param)
|
||||
for name, param in filtered_weights
|
||||
if not name.startswith("audio_tower.")
|
||||
]
|
||||
|
||||
# Load main model weights (LLM + projector) with mapper
|
||||
loader = AutoWeightsLoader(self)
|
||||
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
# Load Whisper encoder weights from subfolder
|
||||
whisper_path = os.path.join(
|
||||
self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors"
|
||||
)
|
||||
if os.path.exists(whisper_path):
|
||||
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
|
||||
loaded.update(whisper_loaded)
|
||||
|
||||
return loaded
|
||||
|
||||
def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
|
||||
"""Load Whisper encoder weights from safetensors file with transformations."""
|
||||
if not os.path.exists(whisper_path):
|
||||
return set()
|
||||
|
||||
# Step 1: Load raw weights from safetensors file
|
||||
whisper_weights = []
|
||||
with safe_open(whisper_path, framework="pt") as f:
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
if key.startswith("model.encoder.") and "embed_positions" not in key:
|
||||
new_key = key.replace("model.encoder.", "")
|
||||
whisper_weights.append((new_key, f.get_tensor(key)))
|
||||
|
||||
# Step 2: Apply fc → mlp mapping using WeightsMapper
|
||||
fc_mapper = WeightsMapper(
|
||||
orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
|
||||
)
|
||||
whisper_mapped = list(fc_mapper.apply(whisper_weights))
|
||||
|
||||
# Step 3: Apply Q/K/V fusion manually
|
||||
stacked_params_mapping = [
|
||||
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
|
||||
(".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.audio_tower.named_parameters())
|
||||
whisper_loaded: set[str] = set()
|
||||
|
||||
for name, loaded_weight in whisper_mapped:
|
||||
fused = False
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
fused_name = name.replace(weight_name, param_name)
|
||||
if fused_name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[fused_name]
|
||||
param.weight_loader(param, loaded_weight, shard_id)
|
||||
whisper_loaded.add(f"audio_tower.{fused_name}")
|
||||
fused = True
|
||||
break
|
||||
|
||||
if not fused:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
whisper_loaded.add(f"audio_tower.{name}")
|
||||
|
||||
# Add embed_positions which is initialized randomly
|
||||
whisper_loaded.add("audio_tower.embed_positions.weight")
|
||||
|
||||
return whisper_loaded
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls, model_config: ModelConfig, task_type: str
|
||||
) -> SpeechToTextConfig:
|
||||
"""Get speech-to-text config with custom processor."""
|
||||
# Load feature extractor for config values
|
||||
feature_extractor = cached_feature_extractor_from_config(
|
||||
model_config,
|
||||
subfolder=KIMIA_WHISPER_SUBFOLDER,
|
||||
)
|
||||
|
||||
return SpeechToTextConfig(
|
||||
max_audio_clip_s=feature_extractor.chunk_length,
|
||||
sample_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
model_config: ModelConfig,
|
||||
stt_config: SpeechToTextConfig,
|
||||
language: str | None,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: str | None,
|
||||
) -> PromptType:
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
tokenizer_cls=KimiAudioTokenizer,
|
||||
tokenizer_mode=model_config.tokenizer_mode,
|
||||
revision=model_config.tokenizer_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
if task_type not in ("transcribe", "translate"):
|
||||
raise ValueError(
|
||||
f"Unsupported task_type '{task_type}'. "
|
||||
"Supported task types are 'transcribe' and 'translate'."
|
||||
)
|
||||
|
||||
# Incorporate request_prompt as context/instruction if provided
|
||||
user_content = (
|
||||
f"{request_prompt}\n{cls.AUDIO_PLACEHOLDER}"
|
||||
if request_prompt
|
||||
else cls.AUDIO_PLACEHOLDER
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"<|im_kimia_user_msg_start|>{user_content}"
|
||||
f"<|im_msg_end|><|im_kimia_assistant_msg_start|>"
|
||||
)
|
||||
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
|
||||
return TokensPrompt(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data={"audio": audio},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def post_process_output(cls, text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
return text.strip()
|
||||
@@ -421,6 +421,7 @@ _MULTIMODAL_MODELS = {
|
||||
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
|
||||
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
||||
"KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
|
||||
"MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"), # noqa: E501
|
||||
"LightOnOCRForConditionalGeneration": (
|
||||
"lightonocr",
|
||||
"LightOnOCRForConditionalGeneration",
|
||||
|
||||
49
vllm/renderers/kimi_audio.py
Normal file
49
vllm/renderers/kimi_audio.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, cast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
from vllm.tokenizers.registry import get_tokenizer
|
||||
|
||||
from .hf import HfRenderer, HfTokenizer
|
||||
|
||||
|
||||
class KimiAudioRenderer(HfRenderer):
|
||||
"""Renderer for Kimi-Audio models.
|
||||
|
||||
This renderer uses HfRenderer internally with a custom TikToken tokenizer.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "HfRenderer":
|
||||
"""Create an HfRenderer instance for Kimi-Audio models."""
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
# Extract tokenizer_name from kwargs (already processed by
|
||||
# tokenizer_args_from_config for ModelScope/GGUF/etc)
|
||||
tokenizer_name = tokenizer_kwargs.pop(
|
||||
"tokenizer_name", model_config.tokenizer
|
||||
)
|
||||
# Remove tokenizer_cls from kwargs to avoid duplicate argument
|
||||
tokenizer_kwargs = {
|
||||
k: v for k, v in tokenizer_kwargs.items() if k != "tokenizer_cls"
|
||||
}
|
||||
# Use get_tokenizer directly instead of cached_get_tokenizer
|
||||
# (KimiAudioTokenizer doesn't work with get_cached_tokenizer)
|
||||
tokenizer = cast(
|
||||
HfTokenizer,
|
||||
get_tokenizer(
|
||||
tokenizer_name,
|
||||
tokenizer_cls=KimiAudioTokenizer, # type: ignore[arg-type]
|
||||
**tokenizer_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
return HfRenderer(config, tokenizer)
|
||||
@@ -19,6 +19,7 @@ _VLLM_RENDERERS = {
|
||||
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
|
||||
"hf": ("hf", "HfRenderer"),
|
||||
"grok2": ("grok2", "Grok2Renderer"),
|
||||
"kimi_audio": ("kimi_audio", "KimiAudioRenderer"),
|
||||
"mistral": ("mistral", "MistralRenderer"),
|
||||
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
|
||||
"terratorch": ("terratorch", "TerratorchRenderer"),
|
||||
@@ -74,10 +75,18 @@ RENDERER_REGISTRY = RendererRegistry(
|
||||
|
||||
def renderer_from_config(config: "VllmConfig", **kwargs):
|
||||
model_config = config.model_config
|
||||
|
||||
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
|
||||
model_config, **kwargs
|
||||
)
|
||||
|
||||
# Override tokenizer_mode for Kimi-Audio models
|
||||
if model_config.architecture == "MoonshotKimiaForCausalLM":
|
||||
tokenizer_mode = "kimi_audio"
|
||||
# Update model_config so other components (e.g., multimodal registry)
|
||||
# also use the correct tokenizer mode
|
||||
model_config.tokenizer_mode = "kimi_audio"
|
||||
|
||||
if (
|
||||
model_config.tokenizer_mode == "auto"
|
||||
and model_config.model_impl == "terratorch"
|
||||
|
||||
410
vllm/tokenizers/kimi_audio.py
Normal file
410
vllm/tokenizers/kimi_audio.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tokenizer for Kimi-Audio using TikToken."""
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, overload
|
||||
|
||||
import pybase64
|
||||
import tiktoken
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AddedToken, BatchEncoding
|
||||
from transformers.utils import chat_template_utils as hf_chat_utils
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers.protocol import TokenizerLike
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _load_tiktoken_encoding(
|
||||
vocab_file: Path, special_tokens: dict[str, int]
|
||||
) -> tuple[Any, dict[str, int]]:
|
||||
"""Load TikToken encoding from vocab file."""
|
||||
mergeable_ranks: dict[bytes, int] = {}
|
||||
with open(vocab_file, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split()
|
||||
if len(parts) == 2:
|
||||
token_b64 = parts[0]
|
||||
rank = int(parts[1])
|
||||
token_bytes = pybase64.b64decode(token_b64)
|
||||
mergeable_ranks[token_bytes] = rank
|
||||
|
||||
tokenizer = tiktoken.Encoding(
|
||||
name=str(vocab_file),
|
||||
pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
|
||||
r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=mergeable_ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
return tokenizer, special_tokens
|
||||
|
||||
|
||||
class KimiAudioTokenizer(TokenizerLike):
|
||||
"""TikToken tokenizer for Kimi-Audio."""
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
path_or_repo_id: str | Path,
|
||||
*args,
|
||||
trust_remote_code: bool = False,
|
||||
revision: str | None = None,
|
||||
download_dir: str | None = None,
|
||||
**kwargs,
|
||||
) -> "KimiAudioTokenizer":
|
||||
if args:
|
||||
logger.debug_once("Ignoring extra positional args for KimiAudioTokenizer.")
|
||||
|
||||
path = Path(path_or_repo_id)
|
||||
if path.is_file():
|
||||
vocab_file = path
|
||||
elif path.is_dir():
|
||||
vocab_file = path / "tiktoken.model"
|
||||
if not vocab_file.is_file():
|
||||
vocab_file = path / "tokenizer.model"
|
||||
else:
|
||||
# Download from HuggingFace Hub
|
||||
repo_id = str(path_or_repo_id)
|
||||
|
||||
# Try to download tiktoken.model or tokenizer.model
|
||||
try:
|
||||
vocab_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="tiktoken.model",
|
||||
revision=revision,
|
||||
local_dir=download_dir,
|
||||
)
|
||||
vocab_file = Path(vocab_path)
|
||||
except Exception:
|
||||
try:
|
||||
vocab_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="tokenizer.model",
|
||||
revision=revision,
|
||||
local_dir=download_dir,
|
||||
)
|
||||
vocab_file = Path(vocab_path)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Could not find tiktoken.model or tokenizer.model in {repo_id}"
|
||||
) from exc
|
||||
|
||||
# Also download tokenizer_config.json if available
|
||||
with contextlib.suppress(Exception):
|
||||
hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="tokenizer_config.json",
|
||||
revision=revision,
|
||||
local_dir=download_dir,
|
||||
)
|
||||
|
||||
if not vocab_file.is_file():
|
||||
raise FileNotFoundError(f"tiktoken.model not found at {vocab_file}.")
|
||||
|
||||
return cls(
|
||||
vocab_file=vocab_file,
|
||||
name_or_path=str(path_or_repo_id),
|
||||
truncation_side=kwargs.get("truncation_side", "left"),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vocab_file: Path,
|
||||
name_or_path: str,
|
||||
truncation_side: str,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.name_or_path = name_or_path
|
||||
self._truncation_side = truncation_side
|
||||
self._vocab_file = vocab_file
|
||||
|
||||
# Load special tokens from tokenizer_config.json
|
||||
special_tokens: dict[str, int] = {}
|
||||
tokenizer_config = vocab_file.parent / "tokenizer_config.json"
|
||||
if tokenizer_config.is_file():
|
||||
with open(tokenizer_config, encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
# Extract special tokens from added_tokens_decoder
|
||||
added_tokens = config.get("added_tokens_decoder", {})
|
||||
for token_id_str, token_info in added_tokens.items():
|
||||
token_id = int(token_id_str)
|
||||
content = token_info.get("content", "")
|
||||
if content:
|
||||
special_tokens[content] = token_id
|
||||
|
||||
self._tokenizer, self._special_tokens = _load_tiktoken_encoding(
|
||||
vocab_file, special_tokens
|
||||
)
|
||||
|
||||
# Build token <-> ID mappings
|
||||
self._token_to_id: dict[str, int] = {}
|
||||
self._id_to_token: dict[int, str] = {}
|
||||
for token_bytes, token_id in self._tokenizer._mergeable_ranks.items():
|
||||
token_str = token_bytes.decode("utf-8", errors="replace")
|
||||
self._token_to_id[token_str] = token_id
|
||||
self._id_to_token[token_id] = token_str
|
||||
|
||||
# Initialize added_tokens_decoder before adding special tokens
|
||||
self._added_tokens_decoder: dict[int, Any] = {}
|
||||
|
||||
# Add Kimi-Audio special tokens
|
||||
self._add_kimiaudio_special_tokens()
|
||||
|
||||
# Set default special token IDs (will be updated when special tokens are added)
|
||||
self._bos_token_id = 151643 # Kimi-Audio BOS
|
||||
self._eos_token_id = 151644 # Kimi-Audio EOS
|
||||
self._pad_token_id = self._eos_token_id
|
||||
self._unk_token_id = self._pad_token_id
|
||||
|
||||
self._max_chars_per_token = max(
|
||||
(len(tok) for tok in self._token_to_id), default=10
|
||||
)
|
||||
|
||||
def _add_kimiaudio_special_tokens(self) -> None:
|
||||
"""Add Kimi-Audio special tokens to the tokenizer."""
|
||||
# Tokens should already be in self._special_tokens from tokenizer_config.json
|
||||
# Just add them to added_tokens_decoder for compatibility
|
||||
kimiaudio_special_tokens = {
|
||||
"<|im_media_begin|>": 151661,
|
||||
"<|im_media_end|>": 151663,
|
||||
"<|im_kimia_text_blank|>": 151666,
|
||||
"<|im_msg_end|>": 151645,
|
||||
"<|im_kimia_user_msg_start|>": 151670,
|
||||
"<|im_kimia_assistant_msg_start|>": 151671,
|
||||
}
|
||||
|
||||
for token_str, token_id in kimiaudio_special_tokens.items():
|
||||
# Only add if not already present
|
||||
if token_id not in self._added_tokens_decoder:
|
||||
self._added_tokens_decoder[token_id] = AddedToken(
|
||||
token_str, single_word=True, normalized=False, special=True
|
||||
)
|
||||
# Also ensure it's in _token_to_id and _id_to_token
|
||||
if token_str not in self._token_to_id:
|
||||
self._token_to_id[token_str] = token_id
|
||||
if token_id not in self._id_to_token:
|
||||
self._id_to_token[token_id] = token_str
|
||||
|
||||
def num_special_tokens_to_add(self) -> int:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
return list(self._added_tokens_decoder.values())
|
||||
|
||||
@property
|
||||
def all_special_ids(self) -> list[int]:
|
||||
return list(self._added_tokens_decoder.keys())
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> int:
|
||||
return self._bos_token_id
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self._eos_token_id
|
||||
|
||||
@property
|
||||
def pad_token_id(self) -> int:
|
||||
return self._pad_token_id
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self._tokenizer.n_vocab
|
||||
|
||||
@property
|
||||
def max_token_id(self) -> int:
|
||||
return self._tokenizer.n_vocab - 1
|
||||
|
||||
@property
|
||||
def max_chars_per_token(self) -> int:
|
||||
return self._max_chars_per_token
|
||||
|
||||
@property
|
||||
def truncation_side(self) -> str:
|
||||
return self._truncation_side
|
||||
|
||||
@property
|
||||
def added_tokens_decoder(self) -> dict[int, Any]:
|
||||
return self._added_tokens_decoder
|
||||
|
||||
@added_tokens_decoder.setter
|
||||
def added_tokens_decoder(self, value: dict[int, Any]) -> None:
|
||||
"""Set added tokens decoder and update special token IDs."""
|
||||
self._added_tokens_decoder = value
|
||||
# Update special token IDs if known tokens are added
|
||||
for token_id, token in value.items():
|
||||
token_str = str(token) if hasattr(token, "__str__") else token
|
||||
if "<|im_kimia_user_msg_start|>" in token_str:
|
||||
self._bos_token_id = token_id
|
||||
elif "<|im_msg_end|>" in token_str or "<|im_end|>" in token_str:
|
||||
self._eos_token_id = token_id
|
||||
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
return dict(self._token_to_id)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return vocab size for compatibility with HF tokenizer interface."""
|
||||
return self._tokenizer.n_vocab
|
||||
|
||||
def get_added_vocab(self) -> dict[str, int]:
|
||||
return {
|
||||
str(token): token_id
|
||||
for token_id, token in self._added_tokens_decoder.items()
|
||||
}
|
||||
|
||||
def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
|
||||
if max_length is None or len(tokens) <= max_length:
|
||||
return tokens
|
||||
if self.truncation_side == "left":
|
||||
return tokens[-max_length:]
|
||||
return tokens[:max_length]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: str,
|
||||
truncation: bool | None = None,
|
||||
max_length: int | None = None,
|
||||
add_special_tokens: bool = True,
|
||||
**kwargs,
|
||||
) -> list[int]:
|
||||
del add_special_tokens
|
||||
# Allow Kimi-Audio special tokens to be encoded
|
||||
tokens = self._tokenizer.encode(
|
||||
text,
|
||||
allowed_special={
|
||||
"<|im_media_begin|>",
|
||||
"<|im_media_end|>",
|
||||
"<|im_kimia_text_blank|>",
|
||||
"<|im_msg_end|>",
|
||||
"<|im_kimia_user_msg_start|>",
|
||||
"<|im_kimia_assistant_msg_start|>",
|
||||
},
|
||||
)
|
||||
if truncation:
|
||||
tokens = self._maybe_truncate(tokens, max_length)
|
||||
return tokens
|
||||
|
||||
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
|
||||
"""Decode token IDs to text, optionally skipping special tokens."""
|
||||
if isinstance(ids, int):
|
||||
ids = [ids]
|
||||
if skip_special_tokens:
|
||||
# Skip tokens that are in special_tokens (loaded from config)
|
||||
special_ids = set(self._special_tokens.values())
|
||||
ids = [token_id for token_id in ids if token_id not in special_ids]
|
||||
return self._tokenizer.decode(ids)
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: str) -> int: ...
|
||||
|
||||
@overload
|
||||
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
|
||||
|
||||
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
|
||||
if isinstance(tokens, str):
|
||||
return self._token_to_id.get(tokens, self._unk_token_id)
|
||||
return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self, ids: list[int], skip_special_tokens: bool = False
|
||||
) -> list[str]:
|
||||
tokens = []
|
||||
for token_id in ids:
|
||||
if skip_special_tokens and token_id in self._added_tokens_decoder:
|
||||
continue
|
||||
tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
|
||||
return tokens
|
||||
|
||||
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||
token_ids = self.convert_tokens_to_ids(tokens)
|
||||
return self.decode(token_ids, skip_special_tokens=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str],
|
||||
text_pair: str | None = None,
|
||||
add_special_tokens: bool = True,
|
||||
truncation: bool = False,
|
||||
max_length: int | None = None,
|
||||
**kwargs,
|
||||
) -> BatchEncoding:
|
||||
if text_pair is not None:
|
||||
raise NotImplementedError(
|
||||
"text_pair is not supported for KimiAudioTokenizer."
|
||||
)
|
||||
|
||||
if isinstance(text, list):
|
||||
input_ids_batch: list[list[int]] = [
|
||||
self.encode(
|
||||
item,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
for item in text
|
||||
]
|
||||
attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
|
||||
return BatchEncoding(
|
||||
{"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
|
||||
)
|
||||
|
||||
input_ids = self.encode(
|
||||
text,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})
|
||||
|
||||
def get_chat_template(
|
||||
self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
|
||||
) -> str | None:
|
||||
del tools
|
||||
return chat_template
|
||||
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
chat_template: str | None = None,
|
||||
tokenize: bool = False,
|
||||
**kwargs,
|
||||
) -> str | list[int]:
|
||||
# Handle both 'messages' (protocol) and 'conversation' (caller) parameter names
|
||||
conversation = messages if messages is not None else kwargs.get("conversation")
|
||||
if conversation is None:
|
||||
raise ValueError("Either 'messages' or 'conversation' must be provided.")
|
||||
template = self.get_chat_template(chat_template, tools=tools)
|
||||
if template is None:
|
||||
raise ValueError(
|
||||
"No chat template available. Provide `chat_template` explicitly."
|
||||
)
|
||||
# Use render_jinja_template instead of apply_chat_template
|
||||
# Note: render_jinja_template returns ([prompts], [generation_indices])
|
||||
rendered, _ = hf_chat_utils.render_jinja_template(
|
||||
conversation,
|
||||
chat_template=template,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
# Extract the first (and usually only) prompt
|
||||
prompt = rendered[0] if rendered else ""
|
||||
if tokenize:
|
||||
return self.encode(prompt, add_special_tokens=False)
|
||||
return prompt
|
||||
@@ -35,6 +35,7 @@ _VLLM_TOKENIZERS = {
|
||||
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
|
||||
"grok2": ("grok2", "Grok2Tokenizer"),
|
||||
"hf": ("hf", "CachedHfTokenizer"),
|
||||
"kimi_audio": ("kimi_audio", "KimiAudioTokenizer"),
|
||||
"mistral": ("mistral", "MistralTokenizer"),
|
||||
"qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
{% set messages = conversations[0] if conversations else [] -%}
|
||||
{% if messages and messages[0]['role'] == 'system' -%}
|
||||
{% set loop_messages = messages[1:] -%}
|
||||
{% else -%}
|
||||
{% set loop_messages = messages -%}
|
||||
{% endif -%}
|
||||
{% for message in loop_messages -%}
|
||||
{% if message['role'] == 'user' -%}
|
||||
<|im_kimia_user_msg_start|>{{ message['content'] }}<|im_msg_end|><|im_kimia_assistant_msg_start|>
|
||||
{%- elif message['role'] == 'assistant' -%}
|
||||
{{ message['content'] }}<|im_kimia_text_eos|>
|
||||
{%- endif -%}
|
||||
{% endfor -%}
|
||||
@@ -10,23 +10,6 @@ reasons:
|
||||
|
||||
import importlib
|
||||
|
||||
_CLASS_TO_MODULE: dict[str, str] = {
|
||||
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
|
||||
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
|
||||
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
|
||||
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
|
||||
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
|
||||
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
|
||||
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
|
||||
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
|
||||
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
|
||||
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
|
||||
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
|
||||
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
|
||||
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BagelProcessor",
|
||||
"DeepseekVLV2Processor",
|
||||
@@ -35,6 +18,7 @@ __all__ = [
|
||||
"GLM4VProcessor",
|
||||
"HunYuanVLProcessor",
|
||||
"HunYuanVLImageProcessor",
|
||||
"KimiAudioProcessor",
|
||||
"MistralCommonPixtralProcessor",
|
||||
"MistralCommonVoxtralProcessor",
|
||||
"OvisProcessor",
|
||||
@@ -43,6 +27,23 @@ __all__ = [
|
||||
"Qwen3ASRProcessor",
|
||||
]
|
||||
|
||||
_CLASS_TO_MODULE: dict[str, str] = {
|
||||
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
|
||||
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
|
||||
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
|
||||
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
|
||||
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
|
||||
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
|
||||
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
|
||||
"KimiAudioProcessor": "vllm.transformers_utils.processors.kimi_audio",
|
||||
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
|
||||
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
|
||||
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
|
||||
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
|
||||
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
|
||||
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name in _CLASS_TO_MODULE:
|
||||
|
||||
163
vllm/transformers_utils/processors/kimi_audio.py
Normal file
163
vllm/transformers_utils/processors/kimi_audio.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Processor for Kimi-Audio ASR model."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoFeatureExtractor, BatchFeature, ProcessorMixin
|
||||
from transformers.audio_utils import AudioInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after Whisper feature extraction."""
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
|
||||
|
||||
class KimiAudioProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Kimi-Audio processor.
|
||||
|
||||
[`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer.
|
||||
See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
|
||||
The audio feature extractor.
|
||||
tokenizer ([`PreTrainedTokenizer`], *optional*):
|
||||
The text tokenizer.
|
||||
"""
|
||||
|
||||
# Required for ProcessorMixin
|
||||
attributes = ["feature_extractor", "tokenizer"]
|
||||
feature_extractor_class = "AutoFeatureExtractor"
|
||||
tokenizer_class = "AutoTokenizer"
|
||||
|
||||
# Special token IDs
|
||||
KIMIA_MEDIA_BEGIN: int = 151661
|
||||
KIMIA_MEDIA_END: int = 151663
|
||||
KIMIA_TEXT_BLANK: int = 151666
|
||||
|
||||
# Audio processing constants
|
||||
AUDIO_SEQ_LEN: int = 376
|
||||
|
||||
def __init__(self, feature_extractor=None, tokenizer=None, **kwargs):
|
||||
# Pass feature_extractor and tokenizer to parent ProcessorMixin
|
||||
super().__init__(
|
||||
feature_extractor=feature_extractor,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def check_argument_for_proper_class(self, attribute_name: str, argument: Any):
|
||||
"""Override to skip class validation for custom tokenizer."""
|
||||
# Skip validation for tokenizer since KimiAudioTokenizer doesn't inherit
|
||||
# from PreTrainedTokenizerBase but is compatible
|
||||
if attribute_name == "tokenizer" and argument is not None:
|
||||
return
|
||||
# For other attributes, use default validation
|
||||
super().check_argument_for_proper_class(attribute_name, argument)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput = None,
|
||||
audio: AudioInput = None,
|
||||
return_tensors: str = "pt",
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and audio(s).
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`):
|
||||
The sequence or batch of sequences to be encoded.
|
||||
audio (`np.ndarray`, `List[np.ndarray]`):
|
||||
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
|
||||
return_tensors (`str`):
|
||||
The type of tensors to return ("pt", "np", etc.)
|
||||
"""
|
||||
if text is None:
|
||||
raise ValueError("You need to specify either a `text` input to process.")
|
||||
|
||||
# Process audio if provided
|
||||
if audio is not None:
|
||||
# Ensure audio is a list
|
||||
if isinstance(audio, np.ndarray):
|
||||
audio = [audio]
|
||||
|
||||
# Pad audio to hop length (required by WhisperFeatureExtractor)
|
||||
hop_length = self.feature_extractor.hop_length
|
||||
padded_audio = []
|
||||
for aud in audio:
|
||||
length = aud.shape[-1]
|
||||
if length % hop_length != 0:
|
||||
pad_length = hop_length - (length % hop_length)
|
||||
aud = np.pad(
|
||||
aud, (0, pad_length), mode="constant", constant_values=0
|
||||
)
|
||||
padded_audio.append(aud)
|
||||
|
||||
# Use feature_extractor directly like Qwen3ASR does
|
||||
audio_inputs = self.feature_extractor(
|
||||
padded_audio,
|
||||
sampling_rate=16000,
|
||||
padding=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
# Rename to match Kimi-Audio expectations
|
||||
if "input_features" in audio_inputs:
|
||||
audio_inputs["whisper_input_features"] = audio_inputs.pop(
|
||||
"input_features"
|
||||
)
|
||||
if "attention_mask" in audio_inputs:
|
||||
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
|
||||
"attention_mask"
|
||||
)
|
||||
else:
|
||||
audio_inputs = {}
|
||||
|
||||
# Handle text input - can be string or token IDs from vLLM processor
|
||||
if isinstance(text, list) and len(text) > 0 and isinstance(text[0], int):
|
||||
# Text is already token IDs (from vLLM processor) - just wrap
|
||||
text_inputs = {"input_ids": torch.tensor([text], dtype=torch.long)}
|
||||
else:
|
||||
# Text is string - tokenize
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text, return_tensors=return_tensors, padding=True
|
||||
)
|
||||
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **audio_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
Reference in New Issue
Block a user