[Model] Add support for moonshotai/Kimi-Audio-7B-Instruct (#36127)
Signed-off-by: tunglinwood <tunglinwood@gmail.com> Signed-off-by: tunglinwood <tomwu.tunglin@gmail.com> Signed-off-by: tunglinwood <113751333+tunglinwood@users.noreply.github.com>
This commit is contained in:
@@ -713,8 +713,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
|||||||
| `KananaVForConditionalGeneration` | Kanana-V | T + I<sup>+</sup> | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ |
|
| `KananaVForConditionalGeneration` | Kanana-V | T + I<sup>+</sup> | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ |
|
||||||
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
|
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
|
||||||
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
|
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
|
||||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
|
| `KimiAudioForConditionalGeneration` | Kimi-Audio | T + A<sup>+</sup> | `moonshotai/Kimi-Audio-7B-Instruct` | | ✅︎ |
|
||||||
| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I<sup>+</sup> | `moonshotai/Kimi-K2.5` | | ✅︎ |
|
| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I<sup>+</sup> | `moonshotai/Kimi-K2.5` | | ✅︎ |
|
||||||
|
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
|
||||||
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
|
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
|
||||||
| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I<sup>+</sup> | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
|
| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I<sup>+</sup> | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
|
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@@ -201,6 +201,34 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Kimi-Audio-7B-Instruct
|
||||||
|
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build the engine arguments and prompt for Kimi-Audio-7B-Instruct.

    Args:
        question: Instruction appended after the audio placeholders. An empty
            value falls back to a plain transcription request.
        audio_count: Number of audio items; one placeholder token is emitted
            per item and the same number is used as the per-prompt audio limit.

    Returns:
        The request data (engine args, prompt, stop token IDs) for the example
        runner.
    """
    # Placeholder token that stands in for the audio features in the prompt.
    blank_token = "<|im_kimia_text_blank|>"
    instruction = question or "Please transcribe the audio"

    return ModelRequestData(
        engine_args=EngineArgs(
            model="moonshotai/Kimi-Audio-7B-Instruct",
            trust_remote_code=True,
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=blank_token * audio_count + instruction,
        # Generation stops on token 151644 to prevent repeated output.
        stop_token_ids=[151644],
    )
|
||||||
|
|
||||||
|
|
||||||
# MiDashengLM
|
# MiDashengLM
|
||||||
def run_midashenglm(question: str, audio_count: int):
|
def run_midashenglm(question: str, audio_count: int):
|
||||||
model_name = "mispeech/midashenglm-7b"
|
model_name = "mispeech/midashenglm-7b"
|
||||||
@@ -485,6 +513,7 @@ model_example_map = {
|
|||||||
"glmasr": run_glmasr,
|
"glmasr": run_glmasr,
|
||||||
"funaudiochat": run_funaudiochat,
|
"funaudiochat": run_funaudiochat,
|
||||||
"granite_speech": run_granite_speech,
|
"granite_speech": run_granite_speech,
|
||||||
|
"kimi_audio": run_kimi_audio,
|
||||||
"midashenglm": run_midashenglm,
|
"midashenglm": run_midashenglm,
|
||||||
"minicpmo": run_minicpmo,
|
"minicpmo": run_minicpmo,
|
||||||
"phi4_mm": run_phi4mm,
|
"phi4_mm": run_phi4mm,
|
||||||
|
|||||||
@@ -198,13 +198,17 @@ def get_text_token_prompts(
|
|||||||
mm_counts,
|
mm_counts,
|
||||||
mm_options={},
|
mm_options={},
|
||||||
)
|
)
|
||||||
assert isinstance(inputs.prompt, str)
|
# Some models (e.g., Kimi-Audio) return token IDs directly instead of str
|
||||||
|
if isinstance(inputs.prompt, list):
|
||||||
text_prompt = inputs.prompt
|
text_prompt = None
|
||||||
token_prompt = tokenizer.encode(
|
token_prompt = inputs.prompt
|
||||||
text_prompt,
|
else:
|
||||||
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True),
|
assert isinstance(inputs.prompt, str)
|
||||||
)
|
text_prompt = inputs.prompt
|
||||||
|
token_prompt = tokenizer.encode(
|
||||||
|
text_prompt,
|
||||||
|
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True),
|
||||||
|
)
|
||||||
|
|
||||||
return text_prompt, token_prompt
|
return text_prompt, token_prompt
|
||||||
|
|
||||||
|
|||||||
@@ -857,6 +857,15 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"Kwai-Keye/Keye-VL-1_5-8B",
|
"Kwai-Keye/Keye-VL-1_5-8B",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
),
|
),
|
||||||
|
"MoonshotKimiaForCausalLM": _HfExamplesInfo(
|
||||||
|
"moonshotai/Kimi-Audio-7B-Instruct",
|
||||||
|
tokenizer_mode="kimi_audio",
|
||||||
|
trust_remote_code=True,
|
||||||
|
),
|
||||||
|
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
|
||||||
|
"moonshotai/Kimi-K2.5",
|
||||||
|
trust_remote_code=True,
|
||||||
|
),
|
||||||
"KimiVLForConditionalGeneration": _HfExamplesInfo(
|
"KimiVLForConditionalGeneration": _HfExamplesInfo(
|
||||||
"moonshotai/Kimi-VL-A3B-Instruct",
|
"moonshotai/Kimi-VL-A3B-Instruct",
|
||||||
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
|
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
|
||||||
@@ -870,10 +879,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
|
|
||||||
"moonshotai/Kimi-K2.5",
|
|
||||||
trust_remote_code=True,
|
|
||||||
),
|
|
||||||
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
|
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
|
||||||
"lightonai/LightOnOCR-1B-1025"
|
"lightonai/LightOnOCR-1B-1025"
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -103,6 +103,12 @@ def can_initialize(
|
|||||||
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
|
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_arch == "MoonshotKimiaForCausalLM":
|
||||||
|
pytest.skip(
|
||||||
|
"Kimi-Audio requires SpeechToTextConfig "
|
||||||
|
"which is not configured in test environment"
|
||||||
|
)
|
||||||
|
|
||||||
if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]:
|
if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]:
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|||||||
725
vllm/model_executor/models/kimi_audio.py
Normal file
725
vllm/model_executor/models/kimi_audio.py
Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
|
from typing import Any, ClassVar, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from safetensors import safe_open
|
||||||
|
from transformers import BatchFeature
|
||||||
|
from transformers import WhisperConfig as HFWhisperConfig
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||||
|
from vllm.inputs.data import PromptType, TokensPrompt
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.interfaces import (
|
||||||
|
SupportsMultiModal,
|
||||||
|
SupportsPP,
|
||||||
|
SupportsTranscription,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.utils import (
|
||||||
|
AutoWeightsLoader,
|
||||||
|
WeightsMapper,
|
||||||
|
init_vllm_registered_model,
|
||||||
|
maybe_prefix,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.models.whisper import WhisperEncoder
|
||||||
|
from vllm.model_executor.models.whisper_utils import ISO639_1_SUPPORTED_LANGS
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.inputs import MultiModalFieldConfig
|
||||||
|
from vllm.multimodal.parse import (
|
||||||
|
AudioItem,
|
||||||
|
DictEmbeddingItems,
|
||||||
|
ModalityData,
|
||||||
|
ModalityDataItems,
|
||||||
|
MultiModalDataParser,
|
||||||
|
)
|
||||||
|
from vllm.multimodal.processing import (
|
||||||
|
BaseDummyInputsBuilder,
|
||||||
|
BaseProcessingInfo,
|
||||||
|
PromptReplacement,
|
||||||
|
)
|
||||||
|
from vllm.multimodal.processing.processor import BaseMultiModalProcessor
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.tokenizers import cached_get_tokenizer
|
||||||
|
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||||
|
from vllm.transformers_utils.processor import cached_feature_extractor_from_config
|
||||||
|
from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor
|
||||||
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
# Kimi-Audio constants
|
||||||
|
# Name of the checkpoint subfolder holding the Whisper assets; this module uses
# it to locate the Whisper config, the feature extractor and the encoder
# weights file.
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Compute output lengths after Whisper feature extraction.
|
||||||
|
|
||||||
|
Whisper processes audio through multiple conv layers with stride=2,
|
||||||
|
producing 13 output features per 100 input samples.
|
||||||
|
"""
|
||||||
|
input_lengths_leave = input_lengths % 100
|
||||||
|
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||||
|
output_lengths = (
|
||||||
|
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||||
|
)
|
||||||
|
return output_lengths
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioWhisperEncoder(WhisperEncoder):
    """WhisperEncoder configured from the checkpoint's Whisper subfolder.

    Kimi-Audio keeps the Whisper configuration in
    ``<model>/whisper-large-v3/config.json`` rather than in the top-level
    config, so the encoder is built against that config instead of the
    model's own ``hf_config``.
    """

    # Fused projection name -> per-projection names, used for Q/K/V fusion
    # during weight loading.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "kv_proj": ["k_proj", "v_proj"],
    }

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        """Build the encoder using the Whisper config from the subfolder.

        Args:
            vllm_config: Engine configuration; its ``model_config.hf_config``
                is swapped for the Whisper config only while the parent
                constructor runs.
            prefix: Parameter-name prefix forwarded to ``WhisperEncoder``.
            init_in_fp32: Forwarded to ``WhisperEncoder``.
        """
        model_config = vllm_config.model_config

        # ``subfolder=`` resolves both a local checkpoint directory and a hub
        # repo id; joining the subfolder onto the model string only works for
        # the former.
        whisper_config = HFWhisperConfig.from_pretrained(
            model_config.model, subfolder=KIMIA_WHISPER_SUBFOLDER
        )

        # WhisperEncoder.__init__() reads the config from vllm_config, so the
        # Whisper config is installed temporarily. The original is restored in
        # ``finally`` so a failure during construction cannot leave the shared
        # model config pointing at the Whisper config.
        original_config = model_config.hf_config
        model_config.hf_config = whisper_config
        try:
            super().__init__(
                vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
            )
        finally:
            model_config.hf_config = original_config
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Processing Info, Dummy Inputs, and MultiModal Processor
|
||||||
|
# (Following Qwen3ASR pattern - same file as model)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Kimi-Audio used by the multimodal registry."""

    def get_hf_config(self):
        """Return the HF config stored on the model config."""
        model_config = self.ctx.model_config
        return model_config.hf_config

    def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
        """Assemble a KimiAudioProcessor from its two components.

        The feature extractor comes from vLLM's cached loader (Whisper
        subfolder) and the tokenizer from ``get_tokenizer()``; ``kwargs`` are
        accepted but not used.
        """
        return KimiAudioProcessor(
            feature_extractor=cached_feature_extractor_from_config(
                self.ctx.model_config,
                subfolder=KIMIA_WHISPER_SUBFOLDER,
            ),
            tokenizer=self.get_tokenizer(),
        )

    def get_feature_extractor(self, **kwargs: object):
        """Return the cached feature extractor from the Whisper subfolder."""
        model_config = self.ctx.model_config
        return cached_feature_extractor_from_config(
            model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Declare the per-prompt limit: a single audio item."""
        limits = {"audio": 1}
        return limits

    def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
        """Create the audio data parser with the expected hidden size."""
        hidden_size = self._get_expected_hidden_size()
        return KimiAudioMultiModalDataParser(expected_hidden_size=hidden_size)
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
    """Builds dummy prompts and dummy audio for the multimodal registry."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
        """Return the dummy prompt as token IDs rather than text.

        With audio items, the prompt is one begin/blank/end placeholder triple
        per item; without audio it is a single fixed token ID.
        """
        audio_count = mm_counts.get("audio", 0)
        if audio_count != 0:
            # Token IDs are returned directly to avoid going through the
            # tokenizer.
            placeholder_triple = [
                KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
                KimiAudioProcessor.KIMIA_TEXT_BLANK,
                KimiAudioProcessor.KIMIA_MEDIA_END,
            ]
            return placeholder_triple * audio_count

        # NOTE(review): single placeholder token for the text-only case;
        # confirm which string ID 198 decodes to with this tokenizer.
        return [198]

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Return dummy audio inputs, or an empty dict when no audio is requested.

        Each dummy clip is ``min(chunk_length, 30) * sampling_rate`` samples
        long, with both values read from the feature extractor.
        """
        audio_count = mm_counts.get("audio", 0)
        if audio_count == 0:
            return {}

        extractor = self.info.get_feature_extractor()
        clip_seconds = min(extractor.chunk_length, 30)
        clip_samples = clip_seconds * extractor.sampling_rate

        dummy_audios = self._get_dummy_audios(
            length=clip_samples, num_audios=audio_count
        )
        return {"audio": dummy_audios}
|
||||||
|
|
||||||
|
|
||||||
|
# Field config for Kimi-Audio multimodal data
|
||||||
|
# Both Kimi-Audio multimodal tensors are batched along the "audio" modality.
_KIMIAUDIO_FIELD_CONFIG = {
    field_name: MultiModalFieldConfig.batched("audio")
    for field_name in ("whisper_input_features", "feature_attention_mask")
}
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioMultiModalDataParser(MultiModalDataParser):
    """Data parser that also accepts pre-computed Kimi-Audio features."""

    def __init__(self, **kwargs):
        """Initialise the parser with a 16 kHz target rate (Whisper input)."""
        whisper_sample_rate = 16000
        super().__init__(target_sr=whisper_sample_rate, **kwargs)

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        """Parse audio input.

        A dict is treated as pre-computed features and must provide
        ``whisper_input_features`` and ``feature_attention_mask``; anything
        else is delegated to the base parser.
        """
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)

        def _fields(hf_inputs):
            # The field layout does not depend on the inputs.
            return _KIMIAUDIO_FIELD_CONFIG

        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"whisper_input_features", "feature_attention_mask"},
            fields_factory=_fields,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioMultiModalProcessor(BaseMultiModalProcessor[KimiAudioProcessingInfo]):
    """vLLM multimodal processor wrapper around KimiAudioProcessor."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Adapt the multimodal inputs and invoke the HF processor.

        The ``audios`` entry is removed from a copy of ``mm_data``. When it is
        non-empty, it is re-added under ``audio`` with every 2-element
        tuple/list reduced to its first element (the waveform of an
        ``(array, sampling_rate)`` pair); other items pass through unchanged.
        """
        processor_inputs = dict(mm_data)
        raw_audios = processor_inputs.pop("audios", [])

        if raw_audios:
            # KimiAudioProcessor expects raw arrays, not (array, sr) pairs.
            processor_inputs["audio"] = [
                item[0]
                if isinstance(item, (tuple, list)) and len(item) == 2
                else item
                for item in raw_audios
            ]

        # Go through the context's call_hf_processor for proper handling.
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        return self.info.ctx.call_hf_processor(
            hf_processor,
            dict(text=prompt, **processor_inputs),
            dict(**mm_kwargs, **tok_kwargs),
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, Any]:
        """Return the fixed Kimi-Audio field configuration."""
        return _KIMIAUDIO_FIELD_CONFIG

    def _get_prompt_updates(
        self,
        mm_items,
        hf_processor_mm_kwargs,
        out_mm_kwargs,
    ) -> Sequence[PromptReplacement]:
        """Expand each audio placeholder token to one token per audio feature.

        Lengths are derived from ``feature_attention_mask`` when the processed
        output has one. A missing entry or a zero length falls back to 376
        placeholder tokens.
        """
        fallback_length = 376  # Default Kimi-Audio sequence length
        blank_id = KimiAudioProcessor.KIMIA_TEXT_BLANK

        attention_mask = out_mm_kwargs.get_data().get("feature_attention_mask")
        if attention_mask is None:
            per_item_lengths = []
        else:
            per_item_lengths = _get_feat_extract_output_lengths(
                attention_mask.sum(-1)
            ).tolist()

        def get_replacement_kimiaudio(item_idx: int):
            token_count = fallback_length
            if item_idx < len(per_item_lengths):
                # A zero length also uses the fallback.
                token_count = per_item_lengths[item_idx] or fallback_length
            return [blank_id] * token_count

        # The target is the placeholder token ID, given as a one-element list.
        audio_replacement = PromptReplacement(
            modality="audio",
            target=[KimiAudioProcessor.KIMIA_TEXT_BLANK],
            replacement=get_replacement_kimiaudio,
        )
        return [audio_replacement]
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# Model Definition
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioMultiModalProjector(nn.Module):
    """Projects audio encoder features into the LLM embedding space.

    Pipeline: Linear[whisper_dim -> llm_dim] -> GELU ->
    Linear[llm_dim -> llm_dim] -> LayerNorm[llm_dim]. The attribute names
    mirror the checkpoint's ``vq_adaptor.layers.{0,3,4}`` entries.
    """

    def __init__(
        self,
        whisper_dim: int = 5120,  # Kimi-Audio custom Whisper encoder dim
        llm_dim: int = 3584,
        prefix: str = "",
    ):
        """Create the adaptor layers.

        Args:
            whisper_dim: Width of the incoming audio features.
            llm_dim: Width of the LLM embeddings.
            prefix: Accepted for interface symmetry; not used here.
        """
        super().__init__()
        self.whisper_dim, self.llm_dim = whisper_dim, llm_dim

        # Creation order is kept: layers.0, layers.3, layers.4.
        self.vq_adaptor_layers_0 = nn.Linear(whisper_dim, llm_dim)
        self.vq_adaptor_layers_3 = nn.Linear(llm_dim, llm_dim)
        self.vq_adaptor_layers_4 = nn.LayerNorm(llm_dim)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project features whose last dim is ``whisper_dim`` to ``llm_dim``."""
        # NOTE(review): GELU is applied between the two linear layers; confirm
        # this matches the activation used by the original adaptor.
        activated = nn.functional.gelu(self.vq_adaptor_layers_0(audio_features))
        return self.vq_adaptor_layers_4(self.vq_adaptor_layers_3(activated))
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
KimiAudioMultiModalProcessor,
|
||||||
|
info=KimiAudioProcessingInfo,
|
||||||
|
dummy_inputs=KimiAudioDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
class KimiAudioForConditionalGeneration(
|
||||||
|
nn.Module,
|
||||||
|
SupportsMultiModal,
|
||||||
|
SupportsPP,
|
||||||
|
SupportsTranscription,
|
||||||
|
):
|
||||||
|
"""Kimi-Audio model for ASR transcription."""
|
||||||
|
|
||||||
|
# Kimi-Audio supports a subset of Whisper's supported languages
|
||||||
|
supported_languages: ClassVar[Mapping[str, str]] = {
|
||||||
|
k: ISO639_1_SUPPORTED_LANGS[k]
|
||||||
|
for k in ["zh", "en", "ja", "ko", "de", "fr", "es", "it", "pt", "ru", "ar"]
|
||||||
|
}
|
||||||
|
supports_transcription: ClassVar[Literal[True]] = True
|
||||||
|
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
|
orig_to_new_prefix={
|
||||||
|
# Audio projector (VQ-Adaptor)
|
||||||
|
"model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
|
||||||
|
"model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
|
||||||
|
"model.vq_adaptor.layers.4.": "multi_modal_projector.vq_adaptor_layers_4.",
|
||||||
|
# Language model
|
||||||
|
"model.layers.": "language_model.model.layers.",
|
||||||
|
# Embeddings and output
|
||||||
|
"model.embed_tokens.": "language_model.model.embed_tokens.",
|
||||||
|
"model.norm.": "language_model.model.norm.",
|
||||||
|
"lm_head.": "language_model.lm_head.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio placeholder token sequence
|
||||||
|
AUDIO_PLACEHOLDER = "<|im_media_begin|><|im_kimia_text_blank|><|im_media_end|>"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||||
|
return cls.AUDIO_PLACEHOLDER if modality.startswith("audio") else None
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the audio tower, the projector and the Qwen2 language model.

    Args:
        vllm_config: Engine configuration; supplies the HF config, the
            quantization config, the multimodal config and the model path.
        prefix: Parameter-name prefix under which submodules are registered.
    """
    super().__init__()
    model_config = vllm_config.model_config
    hf_config = model_config.hf_config

    self.config = hf_config
    self.quant_config = vllm_config.quant_config
    self.multimodal_config = model_config.multimodal_config
    self.model_path = model_config.model

    # Submodules are created in the same order as before:
    # encoder -> projector -> language model -> logits processor.
    self.audio_tower = KimiAudioWhisperEncoder(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "audio_tower"),
    )

    # The adaptor input width falls back to 5120 when the config lacks
    # ``kimia_adaptor_input_dim``.
    adaptor_input_dim = getattr(hf_config, "kimia_adaptor_input_dim", 5120)
    self.multi_modal_projector = KimiAudioMultiModalProjector(
        whisper_dim=adaptor_input_dim,
        llm_dim=hf_config.hidden_size,
        prefix=maybe_prefix(prefix, "multi_modal_projector"),
    )

    # The text backbone is instantiated as a registered Qwen2ForCausalLM.
    llm_vllm_config = vllm_config.with_hf_config(
        hf_config, architectures=["Qwen2ForCausalLM"]
    )
    self.language_model = init_vllm_registered_model(
        vllm_config=llm_vllm_config,
        prefix=maybe_prefix(prefix, "language_model"),
    )

    vocab_size = hf_config.vocab_size
    self.logits_processor = LogitsProcessor(vocab_size, vocab_size)

    # Intermediate-tensor creation is delegated to the language model.
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )
|
||||||
|
|
||||||
|
def _parse_and_validate_audio_input(
|
||||||
|
self, **kwargs: object
|
||||||
|
) -> dict[str, torch.Tensor] | None:
|
||||||
|
whisper_input_features = kwargs.pop("whisper_input_features", None)
|
||||||
|
if whisper_input_features is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {"whisper_input_features": whisper_input_features}
|
||||||
|
|
||||||
|
def _process_audio_input(
|
||||||
|
self, audio_input: dict[str, torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
input_features = audio_input["whisper_input_features"]
|
||||||
|
|
||||||
|
# KimiAudioWhisperEncoder expects list of tensors
|
||||||
|
if input_features.dim() == 3:
|
||||||
|
input_features = input_features.unbind(dim=0)
|
||||||
|
|
||||||
|
# Run through Whisper encoder
|
||||||
|
audio_features = self.audio_tower(input_features)
|
||||||
|
|
||||||
|
# Reshape for 4x downsampling (Whisper outputs at 50Hz, need 12.5Hz)
|
||||||
|
B, T, D = audio_features.shape
|
||||||
|
if T % 4 != 0:
|
||||||
|
pad_len = 4 - (T % 4)
|
||||||
|
audio_features = torch.nn.functional.pad(audio_features, (0, 0, 0, pad_len))
|
||||||
|
T = audio_features.shape[1] # Update T after padding
|
||||||
|
|
||||||
|
audio_features = audio_features.reshape(B, T // 4, D * 4)
|
||||||
|
|
||||||
|
# Project to LLM dimension
|
||||||
|
audio_embeds = self.multi_modal_projector(audio_features)
|
||||||
|
return audio_embeds
|
||||||
|
|
||||||
|
def embed_multimodal(self, **kwargs: object) -> list[torch.Tensor] | None:
|
||||||
|
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||||
|
if audio_input is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
audio_embeds = self._process_audio_input(audio_input)
|
||||||
|
|
||||||
|
# audio_embeds shape: [batch_size, seq_len, hidden_dim]
|
||||||
|
# Return as list of 2D tensors, one per batch item
|
||||||
|
if audio_embeds.dim() == 3:
|
||||||
|
# Unbind batch dimension: [B, T, D] -> list of B tensors [T, D]
|
||||||
|
return list(audio_embeds.unbind(dim=0))
|
||||||
|
else:
|
||||||
|
# Single sample: [T, D] -> wrap in list
|
||||||
|
return [audio_embeds]
|
||||||
|
|
||||||
|
def embed_input_ids(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: tuple[torch.Tensor, ...] | None = None,
|
||||||
|
*,
|
||||||
|
is_multimodal: torch.Tensor | None = None,
|
||||||
|
handle_oov_mm_token: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Embed input IDs and fuse with audio embeddings.
|
||||||
|
|
||||||
|
Kimi-Audio fusion: inputs_embeds = (text_emb + audio_emb) × √2
|
||||||
|
|
||||||
|
For PP compatibility, we use the is_multimodal mask from vLLM engine
|
||||||
|
which is correctly computed per pipeline stage.
|
||||||
|
"""
|
||||||
|
# Get text embeddings
|
||||||
|
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
# is_multimodal must be provided for PP to work correctly
|
||||||
|
if is_multimodal is None or not is_multimodal.any():
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
# multimodal_embeddings[0] contains audio embeddings
|
||||||
|
audio_embeds = multimodal_embeddings[0]
|
||||||
|
|
||||||
|
# Handle different tensor structures
|
||||||
|
if isinstance(audio_embeds, (list, tuple)):
|
||||||
|
audio_embeds = torch.cat(audio_embeds, dim=0)
|
||||||
|
elif audio_embeds.dim() == 3:
|
||||||
|
audio_embeds = audio_embeds.reshape(-1, audio_embeds.shape[-1])
|
||||||
|
|
||||||
|
# In PP, audio_embeds count should match is_multimodal.sum()
|
||||||
|
# For now, use embeddings sequentially
|
||||||
|
# (works for non-PP, PP needs vLLM infra fix)
|
||||||
|
num_mm_tokens = is_multimodal.sum().item()
|
||||||
|
num_audio_embeds = audio_embeds.shape[0]
|
||||||
|
|
||||||
|
# Use the minimum of available embeddings and positions
|
||||||
|
# This ensures we don't access out-of-bounds
|
||||||
|
num_to_use = min(num_audio_embeds, num_mm_tokens)
|
||||||
|
|
||||||
|
# Get positions for the tokens we'll actually process
|
||||||
|
mm_positions = is_multimodal.nonzero(as_tuple=True)[0]
|
||||||
|
actual_mm_mask = torch.zeros_like(is_multimodal)
|
||||||
|
actual_mm_mask[mm_positions[:num_to_use]] = True
|
||||||
|
|
||||||
|
# Use corresponding embeddings
|
||||||
|
used_audio_embeds = audio_embeds[:num_to_use]
|
||||||
|
|
||||||
|
# Save text embeddings at multimodal positions
|
||||||
|
text_at_mm_positions = inputs_embeds[actual_mm_mask].clone()
|
||||||
|
|
||||||
|
# Replace text with audio at multimodal positions
|
||||||
|
inputs_embeds[actual_mm_mask] = used_audio_embeds.to(dtype=inputs_embeds.dtype)
|
||||||
|
|
||||||
|
# Apply Kimi-Audio's unique fusion formula: (text + audio) × √2
|
||||||
|
inputs_embeds[actual_mm_mask] = (
|
||||||
|
inputs_embeds[actual_mm_mask] + text_at_mm_positions
|
||||||
|
) * (2**0.5)
|
||||||
|
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run the language model backbone.

    When ``intermediate_tensors`` is provided, ``inputs_embeds`` is dropped
    and ``None`` is forwarded in its place. Extra keyword arguments are
    ignored.
    """
    embeds = None if intermediate_tensors is not None else inputs_embeds
    return self.language_model.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=embeds,
    )
|
||||||
|
|
||||||
|
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata | None = None,
) -> torch.Tensor | None:
    """Compute logits by applying the logits processor with the LM head."""
    lm_head = self.language_model.lm_head
    return self.logits_processor(lm_head, hidden_states, sampling_metadata)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load weights, skipping MIMO layers (TTS-only) for ASR.

    Weights whose name contains one of the skipped patterns, or whose name
    starts with ``audio_tower.``, are not passed to the main loader. The
    Whisper encoder weights are instead read from
    ``<model_path>/<KIMIA_WHISPER_SUBFOLDER>/model.safetensors`` when that
    file exists.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The set of parameter names reported as loaded by the main loader,
        extended with the names reported by the Whisper loader.
    """
    # MIMO / audio-decoder weights are only needed for TTS; we only do ASR.
    skipped_patterns = (
        "mimo_layers.",
        "mimo_output.",
        "mimo_norm.",
        "audio_decoder.",
    )

    # Filter in a single lazy pass so that a lazily produced ``weights``
    # iterable is never fully materialized in memory before loading.
    main_weights = (
        (name, param)
        for name, param in weights
        if not name.startswith("audio_tower.")
        and not any(pattern in name for pattern in skipped_patterns)
    )

    # Load main model weights (LLM + projector) with mapper
    loader = AutoWeightsLoader(self)
    loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)

    # Load Whisper encoder weights from subfolder
    whisper_path = os.path.join(
        self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors"
    )
    if os.path.exists(whisper_path):
        loaded.update(self._load_whisper_weights_from_file(whisper_path))

    return loaded
|
||||||
|
|
||||||
|
def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
    """Load Whisper encoder weights from a safetensors file into ``audio_tower``.

    Keys starting with ``model.encoder.`` are read (positional embeddings
    excluded), ``.fc1.``/``.fc2.`` are renamed to ``.mlp.fc1.``/``.mlp.fc2.``,
    and separate q/k/v projections are routed into the fused ``qkv_proj``
    parameter through that parameter's ``weight_loader``.

    Returns:
        Names (prefixed with ``audio_tower.``) of the parameters written,
        plus ``audio_tower.embed_positions.weight``. An empty set is
        returned when ``whisper_path`` does not exist.
    """
    if not os.path.exists(whisper_path):
        return set()

    # Step 1: collect the encoder tensors with the checkpoint prefix removed.
    encoder_weights = []
    with safe_open(whisper_path, framework="pt") as reader:
        for key in reader.keys():  # noqa: SIM118
            if not key.startswith("model.encoder.") or "embed_positions" in key:
                continue
            encoder_weights.append(
                (key.replace("model.encoder.", ""), reader.get_tensor(key))
            )

    # Step 2: rename fc -> mlp.fc through a WeightsMapper.
    renamer = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )
    renamed_weights = list(renamer.apply(encoder_weights))

    # Step 3: (fused parameter suffix, checkpoint suffix, shard id)
    qkv_shards = [
        (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
        (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
        (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    ]

    tower_params = dict(self.audio_tower.named_parameters())
    loaded_names: set[str] = set()

    for name, tensor in renamed_weights:
        for fused_suffix, shard_suffix, shard_id in qkv_shards:
            if shard_suffix not in name:
                continue
            fused_name = name.replace(shard_suffix, fused_suffix)
            if fused_name not in tower_params:
                continue

            target = tower_params[fused_name]
            target.weight_loader(target, tensor, shard_id)
            loaded_names.add(f"audio_tower.{fused_name}")
            break
        else:
            # Not a fused q/k/v shard: load directly, silently skipping any
            # name (bias or otherwise) the audio tower does not define.
            if name not in tower_params:
                continue

            target = tower_params[name]
            loader_fn = getattr(target, "weight_loader", default_weight_loader)
            loader_fn(target, tensor)
            loaded_names.add(f"audio_tower.{name}")

    # embed_positions is never read from the file (it is filtered out
    # above) but is still reported as loaded.
    loaded_names.add("audio_tower.embed_positions.weight")

    return loaded_names
|
||||||
|
|
||||||
|
@classmethod
def get_speech_to_text_config(
    cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
    """Build the speech-to-text config from the Whisper feature extractor.

    The feature extractor is loaded from the ``KIMIA_WHISPER_SUBFOLDER``
    subfolder; its ``chunk_length`` becomes the maximum clip length and its
    ``sampling_rate`` the sample rate. ``task_type`` is not consulted.
    """
    extractor = cached_feature_extractor_from_config(
        model_config,
        subfolder=KIMIA_WHISPER_SUBFOLDER,
    )
    clip_seconds = extractor.chunk_length
    rate = extractor.sampling_rate
    return SpeechToTextConfig(max_audio_clip_s=clip_seconds, sample_rate=rate)
|
||||||
|
|
||||||
|
@classmethod
def get_generation_prompt(
    cls,
    audio: np.ndarray,
    model_config: ModelConfig,
    stt_config: SpeechToTextConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType:
    """Build the tokenized prompt for a transcription/translation request.

    The prompt is a single user message holding the audio placeholder
    (preceded by ``request_prompt`` and a newline when one is given),
    followed by the assistant-start marker. ``stt_config``, ``language``
    and ``to_language`` are not used by this implementation.

    Raises:
        ValueError: If ``task_type`` is neither ``"transcribe"`` nor
            ``"translate"``.
    """
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_cls=KimiAudioTokenizer,
        tokenizer_mode=model_config.tokenizer_mode,
        revision=model_config.tokenizer_revision,
        trust_remote_code=model_config.trust_remote_code,
    )

    supported_tasks = ("transcribe", "translate")
    if task_type not in supported_tasks:
        raise ValueError(
            f"Unsupported task_type '{task_type}'. "
            "Supported task types are 'transcribe' and 'translate'."
        )

    # A non-empty request prompt is placed before the audio placeholder.
    if request_prompt:
        user_content = f"{request_prompt}\n{cls.AUDIO_PLACEHOLDER}"
    else:
        user_content = cls.AUDIO_PLACEHOLDER

    prompt = (
        f"<|im_kimia_user_msg_start|>{user_content}"
        f"<|im_msg_end|><|im_kimia_assistant_msg_start|>"
    )

    return TokensPrompt(
        prompt_token_ids=tokenizer.encode(prompt),
        multi_modal_data={"audio": audio},
    )
|
||||||
|
|
||||||
|
@classmethod
def post_process_output(cls, text: str) -> str:
    """Return ``text`` stripped of surrounding whitespace ("" if falsy)."""
    return text.strip() if text else ""
|
||||||
@@ -421,6 +421,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
|
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
|
||||||
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
||||||
"KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
|
"KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
|
||||||
|
"MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"), # noqa: E501
|
||||||
"LightOnOCRForConditionalGeneration": (
|
"LightOnOCRForConditionalGeneration": (
|
||||||
"lightonocr",
|
"lightonocr",
|
||||||
"LightOnOCRForConditionalGeneration",
|
"LightOnOCRForConditionalGeneration",
|
||||||
|
|||||||
49
vllm/renderers/kimi_audio.py
Normal file
49
vllm/renderers/kimi_audio.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||||
|
from vllm.tokenizers.registry import get_tokenizer
|
||||||
|
|
||||||
|
from .hf import HfRenderer, HfTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioRenderer(HfRenderer):
    """Renderer for Kimi-Audio models.

    This renderer uses HfRenderer internally with a custom TikToken tokenizer.
    """

    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "HfRenderer":
        """Create an HfRenderer instance for Kimi-Audio models.

        Args:
            config: The vLLM config; ``config.model_config`` supplies the
                default tokenizer name and the ``skip_tokenizer_init`` flag.
            tokenizer_kwargs: Keyword arguments for tokenizer construction.
                ``tokenizer_name`` (if present) selects the tokenizer and
                ``tokenizer_cls`` is ignored; the dict is left unmodified.

        Returns:
            An ``HfRenderer`` wrapping a ``KimiAudioTokenizer``, or wrapping
            ``None`` when tokenizer initialization is skipped.
        """
        model_config = config.model_config
        if model_config.skip_tokenizer_init:
            return HfRenderer(config, None)

        # tokenizer_name has already been processed by
        # tokenizer_args_from_config (ModelScope/GGUF/etc). Read it without
        # popping so the caller's dict is not mutated.
        tokenizer_name = tokenizer_kwargs.get(
            "tokenizer_name", model_config.tokenizer
        )
        # Drop tokenizer_name (passed positionally) and tokenizer_cls
        # (overridden below) to avoid duplicate-argument errors.
        forwarded_kwargs = {
            key: value
            for key, value in tokenizer_kwargs.items()
            if key not in ("tokenizer_name", "tokenizer_cls")
        }
        # Use get_tokenizer directly instead of cached_get_tokenizer
        # (KimiAudioTokenizer doesn't work with get_cached_tokenizer)
        tokenizer = cast(
            HfTokenizer,
            get_tokenizer(
                tokenizer_name,
                tokenizer_cls=KimiAudioTokenizer,  # type: ignore[arg-type]
                **forwarded_kwargs,
            ),
        )

        return HfRenderer(config, tokenizer)
|
||||||
@@ -19,6 +19,7 @@ _VLLM_RENDERERS = {
|
|||||||
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
|
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
|
||||||
"hf": ("hf", "HfRenderer"),
|
"hf": ("hf", "HfRenderer"),
|
||||||
"grok2": ("grok2", "Grok2Renderer"),
|
"grok2": ("grok2", "Grok2Renderer"),
|
||||||
|
"kimi_audio": ("kimi_audio", "KimiAudioRenderer"),
|
||||||
"mistral": ("mistral", "MistralRenderer"),
|
"mistral": ("mistral", "MistralRenderer"),
|
||||||
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
|
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
|
||||||
"terratorch": ("terratorch", "TerratorchRenderer"),
|
"terratorch": ("terratorch", "TerratorchRenderer"),
|
||||||
@@ -74,10 +75,18 @@ RENDERER_REGISTRY = RendererRegistry(
|
|||||||
|
|
||||||
def renderer_from_config(config: "VllmConfig", **kwargs):
|
def renderer_from_config(config: "VllmConfig", **kwargs):
|
||||||
model_config = config.model_config
|
model_config = config.model_config
|
||||||
|
|
||||||
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
|
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
|
||||||
model_config, **kwargs
|
model_config, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Override tokenizer_mode for Kimi-Audio models
|
||||||
|
if model_config.architecture == "MoonshotKimiaForCausalLM":
|
||||||
|
tokenizer_mode = "kimi_audio"
|
||||||
|
# Update model_config so other components (e.g., multimodal registry)
|
||||||
|
# also use the correct tokenizer mode
|
||||||
|
model_config.tokenizer_mode = "kimi_audio"
|
||||||
|
|
||||||
if (
|
if (
|
||||||
model_config.tokenizer_mode == "auto"
|
model_config.tokenizer_mode == "auto"
|
||||||
and model_config.model_impl == "terratorch"
|
and model_config.model_impl == "terratorch"
|
||||||
|
|||||||
410
vllm/tokenizers/kimi_audio.py
Normal file
410
vllm/tokenizers/kimi_audio.py
Normal file
@@ -0,0 +1,410 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Tokenizer for Kimi-Audio using TikToken."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, overload
|
||||||
|
|
||||||
|
import pybase64
|
||||||
|
import tiktoken
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from transformers import AddedToken, BatchEncoding
|
||||||
|
from transformers.utils import chat_template_utils as hf_chat_utils
|
||||||
|
|
||||||
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.tokenizers.protocol import TokenizerLike
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tiktoken_encoding(
    vocab_file: Path, special_tokens: dict[str, int]
) -> tuple[Any, dict[str, int]]:
    """Build a ``tiktoken.Encoding`` from a base64 rank file.

    Each usable line of ``vocab_file`` holds two whitespace-separated
    fields: a base64-encoded token and its integer rank. Lines with any
    other number of fields (including blank lines) are ignored.

    Returns:
        The encoding together with the ``special_tokens`` mapping that was
        passed in.
    """
    ranks: dict[bytes, int] = {}
    with open(vocab_file, encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.split()
            if len(fields) != 2:
                continue
            encoded_token, rank_text = fields
            rank = int(rank_text)
            ranks[pybase64.b64decode(encoded_token)] = rank

    split_pattern = (
        r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
        r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
    )
    encoding = tiktoken.Encoding(
        name=str(vocab_file),
        pat_str=split_pattern,
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
    return encoding, special_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioTokenizer(TokenizerLike):
|
||||||
|
"""TikToken tokenizer for Kimi-Audio."""
|
||||||
|
|
||||||
|
@classmethod
def from_pretrained(
    cls,
    path_or_repo_id: str | Path,
    *args,
    trust_remote_code: bool = False,
    revision: str | None = None,
    download_dir: str | None = None,
    **kwargs,
) -> "KimiAudioTokenizer":
    """Create a tokenizer from a local file, a local directory or a Hub repo.

    A file path is used as the vocab file directly. A directory is searched
    for ``tiktoken.model`` and then ``tokenizer.model``. Anything else is
    treated as a Hub repo id from which the same two filenames are tried in
    that order; ``tokenizer_config.json`` is fetched on a best-effort basis.

    Only ``truncation_side`` (default ``"left"``) is read from ``kwargs``;
    positional ``args`` and ``trust_remote_code`` are not used.

    Raises:
        ValueError: If neither vocab filename can be downloaded from the repo.
        FileNotFoundError: If the resolved vocab path is not a file.
    """
    if args:
        logger.debug_once("Ignoring extra positional args for KimiAudioTokenizer.")

    local_path = Path(path_or_repo_id)
    if local_path.is_file():
        vocab_file = local_path
    elif local_path.is_dir():
        vocab_file = local_path / "tiktoken.model"
        if not vocab_file.is_file():
            vocab_file = local_path / "tokenizer.model"
    else:
        # Not a local path: download from HuggingFace Hub.
        repo_id = str(path_or_repo_id)
        hub_args = {
            "repo_id": repo_id,
            "revision": revision,
            "local_dir": download_dir,
        }

        last_error = None
        for filename in ("tiktoken.model", "tokenizer.model"):
            try:
                vocab_file = Path(hf_hub_download(filename=filename, **hub_args))
            except Exception as exc:
                last_error = exc
            else:
                break
        else:
            raise ValueError(
                f"Could not find tiktoken.model or tokenizer.model in {repo_id}"
            ) from last_error

        # Also download tokenizer_config.json if available
        with contextlib.suppress(Exception):
            hf_hub_download(filename="tokenizer_config.json", **hub_args)

    if not vocab_file.is_file():
        raise FileNotFoundError(f"tiktoken.model not found at {vocab_file}.")

    return cls(
        vocab_file=vocab_file,
        name_or_path=str(path_or_repo_id),
        truncation_side=kwargs.get("truncation_side", "left"),
    )
|
||||||
|
|
||||||
|
def __init__(
    self,
    *,
    vocab_file: Path,
    name_or_path: str,
    truncation_side: str,
) -> None:
    """Load the TikToken vocabulary and set up token/ID lookup tables.

    Special tokens are read from the ``added_tokens_decoder`` section of a
    ``tokenizer_config.json`` located next to ``vocab_file``, when that
    file exists.
    """
    super().__init__()
    self.name_or_path = name_or_path
    self._truncation_side = truncation_side
    self._vocab_file = vocab_file

    # Collect special tokens (content -> id) from tokenizer_config.json.
    config_special: dict[str, int] = {}
    config_path = vocab_file.parent / "tokenizer_config.json"
    if config_path.is_file():
        with open(config_path, encoding="utf-8") as handle:
            config = json.load(handle)
        for id_text, info in config.get("added_tokens_decoder", {}).items():
            token_id = int(id_text)
            content = info.get("content", "")
            if content:
                config_special[content] = token_id

    self._tokenizer, self._special_tokens = _load_tiktoken_encoding(
        vocab_file, config_special
    )

    # Build token <-> ID mappings. Byte sequences that are not valid UTF-8
    # are decoded with replacement characters.
    self._token_to_id: dict[str, int] = {}
    self._id_to_token: dict[int, str] = {}
    for raw_bytes, rank in self._tokenizer._mergeable_ranks.items():
        text = raw_bytes.decode("utf-8", errors="replace")
        self._token_to_id[text] = rank
        self._id_to_token[rank] = text

    # The decoder must exist before the special tokens are registered.
    self._added_tokens_decoder: dict[int, Any] = {}
    self._add_kimiaudio_special_tokens()

    # Default special token IDs; the added_tokens_decoder setter may
    # overwrite BOS/EOS later.
    self._bos_token_id = 151643  # Kimi-Audio BOS
    self._eos_token_id = 151644  # Kimi-Audio EOS
    self._pad_token_id = self._eos_token_id
    self._unk_token_id = self._pad_token_id

    self._max_chars_per_token = max(
        (len(token) for token in self._token_to_id), default=10
    )
|
||||||
|
|
||||||
|
def _add_kimiaudio_special_tokens(self) -> None:
    """Register the Kimi-Audio marker tokens in the lookup tables.

    Each token is added to ``_added_tokens_decoder``, ``_token_to_id`` and
    ``_id_to_token`` unless an entry for it is already present there.
    """
    marker_ids = {
        "<|im_media_begin|>": 151661,
        "<|im_media_end|>": 151663,
        "<|im_kimia_text_blank|>": 151666,
        "<|im_msg_end|>": 151645,
        "<|im_kimia_user_msg_start|>": 151670,
        "<|im_kimia_assistant_msg_start|>": 151671,
    }

    for marker, marker_id in marker_ids.items():
        # Existing entries always win; nothing is overwritten.
        if marker_id not in self._added_tokens_decoder:
            self._added_tokens_decoder[marker_id] = AddedToken(
                marker, single_word=True, normalized=False, special=True
            )
        self._token_to_id.setdefault(marker, marker_id)
        self._id_to_token.setdefault(marker_id, marker)
|
||||||
|
|
||||||
|
def num_special_tokens_to_add(self) -> int:
    """Return 0: this tokenizer reports no automatically added special tokens."""
    return 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_special_tokens(self) -> list[str]:
|
||||||
|
return list(self._added_tokens_decoder.values())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_special_ids(self) -> list[int]:
|
||||||
|
return list(self._added_tokens_decoder.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def bos_token_id(self) -> int:
|
||||||
|
return self._bos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_token_id(self) -> int:
|
||||||
|
return self._eos_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad_token_id(self) -> int:
|
||||||
|
return self._pad_token_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_fast(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
return self._tokenizer.n_vocab
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_token_id(self) -> int:
|
||||||
|
return self._tokenizer.n_vocab - 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_chars_per_token(self) -> int:
|
||||||
|
return self._max_chars_per_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def truncation_side(self) -> str:
|
||||||
|
return self._truncation_side
|
||||||
|
|
||||||
|
@property
|
||||||
|
def added_tokens_decoder(self) -> dict[int, Any]:
|
||||||
|
return self._added_tokens_decoder
|
||||||
|
|
||||||
|
@added_tokens_decoder.setter
|
||||||
|
def added_tokens_decoder(self, value: dict[int, Any]) -> None:
|
||||||
|
"""Set added tokens decoder and update special token IDs."""
|
||||||
|
self._added_tokens_decoder = value
|
||||||
|
# Update special token IDs if known tokens are added
|
||||||
|
for token_id, token in value.items():
|
||||||
|
token_str = str(token) if hasattr(token, "__str__") else token
|
||||||
|
if "<|im_kimia_user_msg_start|>" in token_str:
|
||||||
|
self._bos_token_id = token_id
|
||||||
|
elif "<|im_msg_end|>" in token_str or "<|im_end|>" in token_str:
|
||||||
|
self._eos_token_id = token_id
|
||||||
|
|
||||||
|
def get_vocab(self) -> dict[str, int]:
|
||||||
|
return dict(self._token_to_id)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""Return vocab size for compatibility with HF tokenizer interface."""
|
||||||
|
return self._tokenizer.n_vocab
|
||||||
|
|
||||||
|
def get_added_vocab(self) -> dict[str, int]:
|
||||||
|
return {
|
||||||
|
str(token): token_id
|
||||||
|
for token_id, token in self._added_tokens_decoder.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
|
||||||
|
if max_length is None or len(tokens) <= max_length:
|
||||||
|
return tokens
|
||||||
|
if self.truncation_side == "left":
|
||||||
|
return tokens[-max_length:]
|
||||||
|
return tokens[:max_length]
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
truncation: bool | None = None,
|
||||||
|
max_length: int | None = None,
|
||||||
|
add_special_tokens: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[int]:
|
||||||
|
del add_special_tokens
|
||||||
|
# Allow Kimi-Audio special tokens to be encoded
|
||||||
|
tokens = self._tokenizer.encode(
|
||||||
|
text,
|
||||||
|
allowed_special={
|
||||||
|
"<|im_media_begin|>",
|
||||||
|
"<|im_media_end|>",
|
||||||
|
"<|im_kimia_text_blank|>",
|
||||||
|
"<|im_msg_end|>",
|
||||||
|
"<|im_kimia_user_msg_start|>",
|
||||||
|
"<|im_kimia_assistant_msg_start|>",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if truncation:
|
||||||
|
tokens = self._maybe_truncate(tokens, max_length)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
|
||||||
|
"""Decode token IDs to text, optionally skipping special tokens."""
|
||||||
|
if isinstance(ids, int):
|
||||||
|
ids = [ids]
|
||||||
|
if skip_special_tokens:
|
||||||
|
# Skip tokens that are in special_tokens (loaded from config)
|
||||||
|
special_ids = set(self._special_tokens.values())
|
||||||
|
ids = [token_id for token_id in ids if token_id not in special_ids]
|
||||||
|
return self._tokenizer.decode(ids)
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def convert_tokens_to_ids(self, tokens: str) -> int: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
|
||||||
|
if isinstance(tokens, str):
|
||||||
|
return self._token_to_id.get(tokens, self._unk_token_id)
|
||||||
|
return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]
|
||||||
|
|
||||||
|
def convert_ids_to_tokens(
|
||||||
|
self, ids: list[int], skip_special_tokens: bool = False
|
||||||
|
) -> list[str]:
|
||||||
|
tokens = []
|
||||||
|
for token_id in ids:
|
||||||
|
if skip_special_tokens and token_id in self._added_tokens_decoder:
|
||||||
|
continue
|
||||||
|
tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def convert_tokens_to_string(self, tokens: list[str]) -> str:
|
||||||
|
token_ids = self.convert_tokens_to_ids(tokens)
|
||||||
|
return self.decode(token_ids, skip_special_tokens=False)
|
||||||
|
|
||||||
|
def __call__(
    self,
    text: str | list[str],
    text_pair: str | None = None,
    add_special_tokens: bool = True,
    truncation: bool = False,
    max_length: int | None = None,
    **kwargs,
) -> BatchEncoding:
    """Tokenize a string or a list of strings into a ``BatchEncoding``.

    The result carries ``input_ids`` and an all-ones ``attention_mask`` of
    matching length; for list input both are lists of lists. No padding is
    applied and extra keyword arguments are ignored.

    Raises:
        NotImplementedError: If ``text_pair`` is provided.
    """
    if text_pair is not None:
        raise NotImplementedError(
            "text_pair is not supported for KimiAudioTokenizer."
        )

    encode_options = {
        "truncation": truncation,
        "max_length": max_length,
        "add_special_tokens": add_special_tokens,
    }

    if isinstance(text, list):
        batch_ids: list[list[int]] = [
            self.encode(item, **encode_options) for item in text
        ]
        batch_masks = [[1] * len(ids) for ids in batch_ids]
        return BatchEncoding(
            {"input_ids": batch_ids, "attention_mask": batch_masks}
        )

    single_ids = self.encode(text, **encode_options)
    single_mask = [1] * len(single_ids)
    return BatchEncoding({"input_ids": single_ids, "attention_mask": single_mask})
|
||||||
|
|
||||||
|
def get_chat_template(
|
||||||
|
self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
|
||||||
|
) -> str | None:
|
||||||
|
del tools
|
||||||
|
return chat_template
|
||||||
|
|
||||||
|
def apply_chat_template(
    self,
    messages: list[ChatCompletionMessageParam] | None = None,
    tools: list[dict[str, Any]] | None = None,
    chat_template: str | None = None,
    tokenize: bool = False,
    **kwargs,
) -> str | list[int]:
    """Render a conversation with a Jinja chat template.

    The conversation comes from ``messages`` or, when that is ``None``,
    from a ``conversation`` keyword argument. All keyword arguments are
    forwarded to the template renderer.

    Returns:
        The first rendered prompt (``""`` if nothing was rendered), or its
        token IDs when ``tokenize`` is true.

    Raises:
        ValueError: If no conversation or no chat template is available.
    """
    # Accept both 'messages' (protocol) and 'conversation' (caller) names.
    conversation = kwargs.get("conversation") if messages is None else messages
    if conversation is None:
        raise ValueError("Either 'messages' or 'conversation' must be provided.")

    template = self.get_chat_template(chat_template, tools=tools)
    if template is None:
        raise ValueError(
            "No chat template available. Provide `chat_template` explicitly."
        )

    # render_jinja_template returns ([prompts], [generation_indices]).
    # NOTE(review): that helper appears to expect a batch of conversations;
    # confirm that passing a single conversation here is intended.
    prompts, _ = hf_chat_utils.render_jinja_template(
        conversation,
        chat_template=template,
        tools=tools,
        **kwargs,
    )

    # Only the first (and usually only) prompt is used.
    prompt = prompts[0] if prompts else ""
    if not tokenize:
        return prompt
    return self.encode(prompt, add_special_tokens=False)
|
||||||
@@ -35,6 +35,7 @@ _VLLM_TOKENIZERS = {
|
|||||||
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
|
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
|
||||||
"grok2": ("grok2", "Grok2Tokenizer"),
|
"grok2": ("grok2", "Grok2Tokenizer"),
|
||||||
"hf": ("hf", "CachedHfTokenizer"),
|
"hf": ("hf", "CachedHfTokenizer"),
|
||||||
|
"kimi_audio": ("kimi_audio", "KimiAudioTokenizer"),
|
||||||
"mistral": ("mistral", "MistralTokenizer"),
|
"mistral": ("mistral", "MistralTokenizer"),
|
||||||
"qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
|
"qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
{% set messages = conversations[0] if conversations else [] -%}
|
||||||
|
{% if messages and messages[0]['role'] == 'system' -%}
|
||||||
|
{% set loop_messages = messages[1:] -%}
|
||||||
|
{% else -%}
|
||||||
|
{% set loop_messages = messages -%}
|
||||||
|
{% endif -%}
|
||||||
|
{% for message in loop_messages -%}
|
||||||
|
{% if message['role'] == 'user' -%}
|
||||||
|
<|im_kimia_user_msg_start|>{{ message['content'] }}<|im_msg_end|><|im_kimia_assistant_msg_start|>
|
||||||
|
{%- elif message['role'] == 'assistant' -%}
|
||||||
|
{{ message['content'] }}<|im_kimia_text_eos|>
|
||||||
|
{%- endif -%}
|
||||||
|
{% endfor -%}
|
||||||
@@ -10,23 +10,6 @@ reasons:
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
_CLASS_TO_MODULE: dict[str, str] = {
|
|
||||||
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
|
|
||||||
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
|
|
||||||
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
|
|
||||||
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
|
|
||||||
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
|
|
||||||
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
|
|
||||||
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
|
|
||||||
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
|
|
||||||
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
|
|
||||||
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
|
|
||||||
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
|
|
||||||
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
|
|
||||||
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BagelProcessor",
|
"BagelProcessor",
|
||||||
"DeepseekVLV2Processor",
|
"DeepseekVLV2Processor",
|
||||||
@@ -35,6 +18,7 @@ __all__ = [
|
|||||||
"GLM4VProcessor",
|
"GLM4VProcessor",
|
||||||
"HunYuanVLProcessor",
|
"HunYuanVLProcessor",
|
||||||
"HunYuanVLImageProcessor",
|
"HunYuanVLImageProcessor",
|
||||||
|
"KimiAudioProcessor",
|
||||||
"MistralCommonPixtralProcessor",
|
"MistralCommonPixtralProcessor",
|
||||||
"MistralCommonVoxtralProcessor",
|
"MistralCommonVoxtralProcessor",
|
||||||
"OvisProcessor",
|
"OvisProcessor",
|
||||||
@@ -43,6 +27,23 @@ __all__ = [
|
|||||||
"Qwen3ASRProcessor",
|
"Qwen3ASRProcessor",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_CLASS_TO_MODULE: dict[str, str] = {
|
||||||
|
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
|
||||||
|
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
|
||||||
|
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
|
||||||
|
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
|
||||||
|
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
|
||||||
|
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
|
||||||
|
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
|
||||||
|
"KimiAudioProcessor": "vllm.transformers_utils.processors.kimi_audio",
|
||||||
|
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
|
||||||
|
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
|
||||||
|
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
|
||||||
|
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
|
||||||
|
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
|
||||||
|
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str):
|
def __getattr__(name: str):
|
||||||
if name in _CLASS_TO_MODULE:
|
if name in _CLASS_TO_MODULE:
|
||||||
|
|||||||
163
vllm/transformers_utils/processors/kimi_audio.py
Normal file
163
vllm/transformers_utils/processors/kimi_audio.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# ruff: noqa
|
||||||
|
# mypy: ignore-errors
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Processor for Kimi-Audio ASR model."""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import AutoFeatureExtractor, BatchFeature, ProcessorMixin
|
||||||
|
from transformers.audio_utils import AudioInput
|
||||||
|
from transformers.tokenization_utils_base import TextInput
|
||||||
|
|
||||||
|
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Compute output lengths after Whisper feature extraction."""
|
||||||
|
input_lengths_leave = input_lengths % 100
|
||||||
|
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||||
|
output_lengths = (
|
||||||
|
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||||
|
)
|
||||||
|
return output_lengths
|
||||||
|
|
||||||
|
|
||||||
|
class KimiAudioProcessor(ProcessorMixin):
    r"""
    Constructs a Kimi-Audio processor.

    [`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer.
    See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information.

    Args:
        feature_extractor ([`WhisperFeatureExtractor`], *optional*):
            The audio feature extractor.
        tokenizer (*optional*):
            The text tokenizer. Its class is not validated, so a compatible
            tokenizer that does not inherit from `PreTrainedTokenizerBase`
            is accepted.
    """

    # Required for ProcessorMixin
    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    # Special token IDs
    KIMIA_MEDIA_BEGIN: int = 151661
    KIMIA_MEDIA_END: int = 151663
    KIMIA_TEXT_BLANK: int = 151666

    # Audio processing constants
    AUDIO_SEQ_LEN: int = 376

    def __init__(self, feature_extractor=None, tokenizer=None, **kwargs):
        """Store the feature extractor and tokenizer via `ProcessorMixin`.

        Args:
            feature_extractor: The audio feature extractor.
            tokenizer: The text tokenizer.
            **kwargs: Forwarded unchanged to `ProcessorMixin.__init__`.
        """
        super().__init__(
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
            **kwargs,
        )

    def check_argument_for_proper_class(self, attribute_name: str, argument: Any):
        """Validate a constructor argument, exempting a provided tokenizer.

        Args:
            attribute_name: Name of the processor attribute being validated.
            argument: The object passed for that attribute.

        Returns:
            ``None`` for a non-``None`` tokenizer; otherwise whatever the
            base-class validation returns.
        """
        # Skip validation for tokenizer since KimiAudioTokenizer doesn't inherit
        # from PreTrainedTokenizerBase but is compatible
        if attribute_name == "tokenizer" and argument is not None:
            return None
        # For other attributes, use default validation and propagate its result
        return super().check_argument_for_proper_class(attribute_name, argument)

    @staticmethod
    def _to_waveforms(audio: AudioInput) -> list[np.ndarray]:
        """Normalize one waveform or a sequence of waveforms to NumPy arrays."""
        # A bare array/tensor is a single waveform; iterating it directly
        # would walk over individual samples instead.
        if isinstance(audio, (np.ndarray, torch.Tensor)):
            audio = [audio]
        return [
            wav.detach().cpu().numpy()
            if isinstance(wav, torch.Tensor)
            else np.asarray(wav)
            for wav in audio
        ]

    @staticmethod
    def _pad_to_multiple(wav: np.ndarray, multiple: int) -> np.ndarray:
        """Zero-pad the last axis of ``wav`` up to a multiple of ``multiple``."""
        remainder = wav.shape[-1] % multiple
        if remainder == 0:
            return wav
        # Pad only the trailing axis: a bare ``(0, n)`` pad width is broadcast
        # by NumPy to every axis and would also grow the leading dimensions
        # of multi-dimensional input.
        pad_width = [(0, 0)] * (wav.ndim - 1) + [(0, multiple - remainder)]
        return np.pad(wav, pad_width, mode="constant", constant_values=0)

    def __call__(
        self,
        text: TextInput = None,
        audio: AudioInput = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several sequences(s) and audio(s).

        Args:
            text (`str`, `List[str]`, `List[int]`):
                The sequence or batch of sequences to be encoded. A non-empty
                list whose first element is an `int` is treated as already
                tokenized IDs and wrapped as a single-row `input_ids` tensor
                without calling the tokenizer.
            audio (`np.ndarray`, `torch.Tensor`, or a list of those):
                The audio or batch of audio to be prepared. The last axis of
                each waveform is zero-padded to a multiple of the feature
                extractor's `hop_length`. `None` or an empty list means no
                audio.
            return_tensors (`str`):
                The type of tensors to return ("pt", "np", etc.)
            **kwargs:
                Accepted for call-site compatibility; not used.

        Returns:
            A [`BatchFeature`] holding the tokenizer outputs and, when audio
            is given, the feature extractor outputs with `input_features`
            renamed to `whisper_input_features` and `attention_mask` renamed
            to `feature_attention_mask`.

        Raises:
            ValueError: If `text` is `None`.
        """
        if text is None:
            raise ValueError("You need to specify a `text` input to process.")

        audio_inputs = {}
        waveforms = self._to_waveforms(audio) if audio is not None else []
        if waveforms:
            # Pad audio to hop length (required by WhisperFeatureExtractor)
            hop_length = self.feature_extractor.hop_length
            padded_audio = [
                self._pad_to_multiple(wav, hop_length) for wav in waveforms
            ]

            # Use feature_extractor directly like Qwen3ASR does
            audio_inputs = self.feature_extractor(
                padded_audio,
                sampling_rate=16000,
                padding=True,
                return_attention_mask=True,
                return_tensors=return_tensors,
            )

            # Rename to match Kimi-Audio expectations
            for source_key, target_key in (
                ("input_features", "whisper_input_features"),
                ("attention_mask", "feature_attention_mask"),
            ):
                if source_key in audio_inputs:
                    audio_inputs[target_key] = audio_inputs.pop(source_key)

        # Handle text input - can be string or token IDs from vLLM processor
        if isinstance(text, list) and len(text) > 0 and isinstance(text[0], int):
            # Text is already token IDs (from vLLM processor) - just wrap
            text_inputs = {"input_ids": torch.tensor([text], dtype=torch.long)}
        else:
            # Text is string - tokenize
            if not isinstance(text, list):
                text = [text]
            text_inputs = self.tokenizer(
                text, return_tensors=return_tensors, padding=True
            )

        return BatchFeature(
            data={**text_inputs, **audio_inputs},
            tensor_type=return_tensors,
        )
|
||||||
Reference in New Issue
Block a user