[Model] Add support for moonshotai/Kimi-Audio-7B-Instruct (#36127)

Signed-off-by: tunglinwood <tunglinwood@gmail.com>
Signed-off-by: tunglinwood <tomwu.tunglin@gmail.com>
Signed-off-by: tunglinwood <113751333+tunglinwood@users.noreply.github.com>
This commit is contained in:
tunglinwood
2026-03-11 12:24:48 +08:00
committed by GitHub
parent a197eda9c3
commit 42fadebecb
14 changed files with 1446 additions and 29 deletions

View File

@@ -713,8 +713,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `KananaVForConditionalGeneration` | Kanana-V | T + I<sup>+</sup> | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ |
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
| `KimiAudioForConditionalGeneration` | Kimi-Audio | T + A<sup>+</sup> | `moonshotai/Kimi-Audio-7B-Instruct` | | ✅︎ |
| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I<sup>+</sup> | `moonshotai/Kimi-K2.5` | | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc. | ✅︎ | ✅︎ |
| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I<sup>+</sup> | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |

View File

@@ -201,6 +201,34 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
)
# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
    """Kimi-Audio-7B-Instruct for audio transcription and understanding."""
    engine_args = EngineArgs(
        model="moonshotai/Kimi-Audio-7B-Instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Fall back to a plain transcription request when no question is given.
    user_text = question or "Please transcribe the audio"

    # Kimi-Audio marks audio feature positions with <|im_kimia_text_blank|>.
    placeholders = "<|im_kimia_text_blank|>" * audio_count

    return ModelRequestData(
        engine_args=engine_args,
        prompt=f"{placeholders}{user_text}",
        # 151644 is Kimi-Audio's EOS token; stopping there prevents repetition.
        stop_token_ids=[151644],
    )
# MiDashengLM
def run_midashenglm(question: str, audio_count: int):
model_name = "mispeech/midashenglm-7b"
@@ -485,6 +513,7 @@ model_example_map = {
"glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech,
"kimi_audio": run_kimi_audio,
"midashenglm": run_midashenglm,
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,

View File

@@ -198,13 +198,17 @@ def get_text_token_prompts(
mm_counts,
mm_options={},
)
assert isinstance(inputs.prompt, str)
text_prompt = inputs.prompt
token_prompt = tokenizer.encode(
text_prompt,
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True),
)
# Some models (e.g., Kimi-Audio) return token IDs directly instead of str
if isinstance(inputs.prompt, list):
text_prompt = None
token_prompt = inputs.prompt
else:
assert isinstance(inputs.prompt, str)
text_prompt = inputs.prompt
token_prompt = tokenizer.encode(
text_prompt,
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True),
)
return text_prompt, token_prompt

View File

@@ -857,6 +857,15 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Kwai-Keye/Keye-VL-1_5-8B",
trust_remote_code=True,
),
"MoonshotKimiaForCausalLM": _HfExamplesInfo(
"moonshotai/Kimi-Audio-7B-Instruct",
tokenizer_mode="kimi_audio",
trust_remote_code=True,
),
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
),
"KimiVLForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-VL-A3B-Instruct",
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
@@ -870,10 +879,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
)
},
),
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
),
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
"lightonai/LightOnOCR-1B-1025"
),

View File

@@ -103,6 +103,12 @@ def can_initialize(
"pickle error when loading `transformers.models.auto.CONFIG_MAPPING`"
)
if model_arch == "MoonshotKimiaForCausalLM":
pytest.skip(
"Kimi-Audio requires SpeechToTextConfig "
"which is not configured in test environment"
)
if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]:
from vllm.platforms import current_platform

View File

@@ -0,0 +1,725 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Kimi-Audio model compatible with HuggingFace weights."""
import os
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, ClassVar, Literal
import numpy as np
import torch
import torch.nn as nn
from safetensors import safe_open
from transformers import BatchFeature
from transformers import WhisperConfig as HFWhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
)
from vllm.model_executor.models.interfaces import (
SupportsMultiModal,
SupportsPP,
SupportsTranscription,
)
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
WeightsMapper,
init_vllm_registered_model,
maybe_prefix,
)
from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.model_executor.models.whisper_utils import ISO639_1_SUPPORTED_LANGS
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (
AudioItem,
DictEmbeddingItems,
ModalityData,
ModalityDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseProcessingInfo,
PromptReplacement,
)
from vllm.multimodal.processing.processor import BaseMultiModalProcessor
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
from vllm.transformers_utils.processor import cached_feature_extractor_from_config
from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor
from vllm.v1.sample.metadata import SamplingMetadata
# Kimi-Audio constants
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after Whisper feature extraction.
Whisper processes audio through multiple conv layers with stride=2,
producing 13 output features per 100 input samples.
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
class KimiAudioWhisperEncoder(WhisperEncoder):
    """WhisperEncoder for Kimi-Audio with packed_modules_mapping.

    Kimi-Audio stores the Whisper encoder config in the checkpoint's
    ``whisper-large-v3`` subfolder (the authoritative source), so it is loaded
    from there and temporarily swapped into ``model_config.hf_config`` while
    the base ``WhisperEncoder`` initializes.
    """

    # packed_modules_mapping for Q/K/V fusion during weight loading
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "kv_proj": ["k_proj", "v_proj"],
    }

    def __init__(
        self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
    ):
        # Load Whisper config from subfolder (authoritative source).
        model_path = vllm_config.model_config.model
        whisper_config_path = os.path.join(model_path, KIMIA_WHISPER_SUBFOLDER)
        whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path)

        # Temporarily replace hf_config for WhisperEncoder.__init__().
        # Restore it in a ``finally`` block so the shared VllmConfig is never
        # left pointing at the Whisper config if initialization raises.
        original_config = vllm_config.model_config.hf_config
        vllm_config.model_config.hf_config = whisper_config
        try:
            super().__init__(
                vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32
            )
        finally:
            vllm_config.model_config.hf_config = original_config
# -----------------------------------------------------------------------------
# Processing Info, Dummy Inputs, and MultiModal Processor
# (Following Qwen3ASR pattern - same file as model)
# -----------------------------------------------------------------------------
class KimiAudioProcessingInfo(BaseProcessingInfo):
    """Processing info for vLLM registry.

    Supplies the HF config, processor, feature extractor, multimodal limits,
    and data parser used by the multimodal processing pipeline.
    """

    def get_hf_config(self):
        """Return the model's HF config from the processing context."""
        return self.ctx.model_config.hf_config

    def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
        """Get KimiAudioProcessor with feature extractor and tokenizer."""
        # Use vLLM's cached loader for the feature extractor, which lives in
        # the checkpoint's whisper-large-v3 subfolder.
        feature_extractor = cached_feature_extractor_from_config(
            self.ctx.model_config,
            subfolder=KIMIA_WHISPER_SUBFOLDER,
        )
        # Use vLLM's standard tokenizer loading (respects tokenizer_mode)
        tokenizer = self.get_tokenizer()
        # Construct processor directly (no AutoProcessor lookup)
        return KimiAudioProcessor(
            feature_extractor=feature_extractor,
            tokenizer=tokenizer,
        )

    def get_feature_extractor(self, **kwargs: object):
        """Get feature extractor using vLLM's cached loader."""
        return cached_feature_extractor_from_config(
            self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """At most one audio clip per prompt."""
        return {"audio": 1}

    def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
        """Get data parser for audio inputs."""
        # NOTE(review): _get_expected_hidden_size is not defined in this class;
        # presumably inherited from BaseProcessingInfo — confirm.
        return KimiAudioMultiModalDataParser(
            expected_hidden_size=self._get_expected_hidden_size(),
        )
class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
    """Builds dummy prompts and audio data for profiling/warmup."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
        """Return dummy text as token IDs directly."""
        num_audios = mm_counts.get("audio", 0)
        if not num_audios:
            # "Transcribe" tokenized
            return [198]
        # One <media_begin><text_blank><media_end> span per audio clip,
        # emitted as raw token IDs to avoid tokenizer issues.
        media_span = [
            KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
            KimiAudioProcessor.KIMIA_TEXT_BLANK,
            KimiAudioProcessor.KIMIA_MEDIA_END,
        ]
        return media_span * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, Any] | None = None,
    ) -> dict[str, Any]:
        num_audios = mm_counts.get("audio", 0)
        if not num_audios:
            return {}
        extractor = self.info.get_feature_extractor()
        # Cap dummy clips at 30 seconds' worth of samples.
        clip_samples = min(extractor.chunk_length, 30) * extractor.sampling_rate
        dummy_audios = self._get_dummy_audios(
            length=clip_samples, num_audios=num_audios
        )
        return {"audio": dummy_audios}
# Field config for Kimi-Audio multimodal data
_KIMIAUDIO_FIELD_CONFIG = {
"whisper_input_features": MultiModalFieldConfig.batched("audio"),
"feature_attention_mask": MultiModalFieldConfig.batched("audio"),
}
class KimiAudioMultiModalDataParser(MultiModalDataParser):
    """Custom data parser for Kimi-Audio multimodal data."""

    def __init__(self, **kwargs):
        # Whisper feature extraction expects 16 kHz audio.
        super().__init__(target_sr=16000, **kwargs)

    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        # Raw audio goes through the default parsing path; pre-computed
        # feature dicts are wrapped directly.
        if not isinstance(data, dict):
            return super()._parse_audio_data(data)
        return DictEmbeddingItems(
            data,
            modality="audio",
            required_fields={"whisper_input_features", "feature_attention_mask"},
            fields_factory=lambda hf_inputs: _KIMIAUDIO_FIELD_CONFIG,
        )
class KimiAudioMultiModalProcessor(BaseMultiModalProcessor[KimiAudioProcessingInfo]):
    """vLLM multi-modal processor wrapper for Kimi-Audio."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HuggingFace processor.

        Converts vLLM's ``{"audios": [(array, sr), ...]}`` format into the
        ``{"audio": [array, ...]}`` format expected by KimiAudioProcessor.
        """
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])
        if audios:
            # Strip sampling rates: (array, sr) pairs become bare arrays.
            # Anything else (already a bare array) is passed through
            # unchanged — the previous explicit np.ndarray branch was
            # identical to the fallback and has been removed.
            mm_data["audio"] = [
                aud[0] if isinstance(aud, (tuple, list)) and len(aud) == 2 else aud
                for aud in audios
            ]
        # Use the context's call_hf_processor for proper handling
        return self.info.ctx.call_hf_processor(
            self.info.get_hf_processor(**mm_kwargs),
            dict(text=prompt, **mm_data),
            dict(**mm_kwargs, **tok_kwargs),
        )

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, Any]:
        """Get multi-modal field configuration."""
        return _KIMIAUDIO_FIELD_CONFIG

    def _get_prompt_updates(
        self,
        mm_items,
        hf_processor_mm_kwargs,
        out_mm_kwargs,
    ) -> Sequence[PromptReplacement]:
        """Get prompt updates for audio tokens.

        Each audio placeholder token is expanded to one KIMIA_TEXT_BLANK per
        encoder output frame so the placeholders align 1:1 with the audio
        embeddings produced by the encoder.
        """
        # Get audio feature lengths from processed output
        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if feature_attention_mask is not None:
            audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            audio_output_lengths = audio_output_lens.tolist()
        else:
            audio_output_lengths = []

        def get_replacement_kimiaudio(item_idx: int):
            # Fall back to 376 (default Kimi-Audio sequence length) when the
            # feature length is unknown or reported as zero.
            num_features = (
                audio_output_lengths[item_idx]
                if item_idx < len(audio_output_lengths)
                else 376
            )
            if num_features == 0:
                num_features = 376
            # Return the placeholder token ID repeated num_features times
            return [KimiAudioProcessor.KIMIA_TEXT_BLANK] * num_features

        # Use the token ID as target (as a list)
        return [
            PromptReplacement(
                modality="audio",
                target=[KimiAudioProcessor.KIMIA_TEXT_BLANK],
                replacement=get_replacement_kimiaudio,
            ),
        ]
# -----------------------------------------------------------------------------
# Model Definition
# -----------------------------------------------------------------------------
class KimiAudioMultiModalProjector(nn.Module):
    """Projects Whisper features to LLM embedding space.

    Mirrors the checkpoint's VQ-Adaptor structure:
    Linear[whisper_dim -> llm_dim] -> GELU -> Linear[llm_dim -> llm_dim]
    -> LayerNorm[llm_dim].
    """

    def __init__(
        self,
        whisper_dim: int = 5120,  # Kimi-Audio custom Whisper encoder dim
        llm_dim: int = 3584,
        prefix: str = "",
    ):
        super().__init__()
        self.whisper_dim = whisper_dim
        self.llm_dim = llm_dim
        # Attribute names mirror the checkpoint's layers.0/.3/.4 structure
        # (the gap corresponds to the activation applied in forward()).
        self.vq_adaptor_layers_0 = nn.Linear(whisper_dim, llm_dim)
        self.vq_adaptor_layers_3 = nn.Linear(llm_dim, llm_dim)
        self.vq_adaptor_layers_4 = nn.LayerNorm(llm_dim)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Project [B, T, whisper_dim] features to [B, T, llm_dim]."""
        projected = torch.nn.functional.gelu(
            self.vq_adaptor_layers_0(audio_features)
        )
        projected = self.vq_adaptor_layers_3(projected)
        return self.vq_adaptor_layers_4(projected)
@MULTIMODAL_REGISTRY.register_processor(
    KimiAudioMultiModalProcessor,
    info=KimiAudioProcessingInfo,
    dummy_inputs=KimiAudioDummyInputsBuilder,
)
class KimiAudioForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsTranscription,
):
    """Kimi-Audio model for ASR transcription.

    Composition: a customized Whisper encoder (``audio_tower``), a VQ-Adaptor
    projector (``multi_modal_projector``), and a Qwen2-based language model
    (``language_model``). Audio embeddings are fused with text embeddings as
    ``(text + audio) * sqrt(2)`` at the audio placeholder positions.
    """

    # Kimi-Audio supports a subset of Whisper's supported languages
    supported_languages: ClassVar[Mapping[str, str]] = {
        k: ISO639_1_SUPPORTED_LANGS[k]
        for k in ["zh", "en", "ja", "ko", "de", "fr", "es", "it", "pt", "ru", "ar"]
    }
    supports_transcription: ClassVar[Literal[True]] = True

    # Maps HF checkpoint weight name prefixes to vLLM module paths.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Audio projector (VQ-Adaptor)
            "model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.",
            "model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.",
            "model.vq_adaptor.layers.4.": "multi_modal_projector.vq_adaptor_layers_4.",
            # Language model
            "model.layers.": "language_model.model.layers.",
            # Embeddings and output
            "model.embed_tokens.": "language_model.model.embed_tokens.",
            "model.norm.": "language_model.model.norm.",
            "lm_head.": "language_model.lm_head.",
        }
    )

    # Audio placeholder token sequence
    AUDIO_PLACEHOLDER = "<|im_media_begin|><|im_kimia_text_blank|><|im_media_end|>"

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the placeholder string for audio; None for other modalities."""
        return cls.AUDIO_PLACEHOLDER if modality.startswith("audio") else None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.multimodal_config = vllm_config.model_config.multimodal_config
        # Checkpoint path, needed later to locate the Whisper subfolder.
        self.model_path = vllm_config.model_config.model
        # Whisper-based audio encoder (config read from checkpoint subfolder).
        self.audio_tower = KimiAudioWhisperEncoder(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "audio_tower"),
        )
        self.multi_modal_projector = KimiAudioMultiModalProjector(
            whisper_dim=getattr(self.config, "kimia_adaptor_input_dim", 5120),
            llm_dim=self.config.hidden_size,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        # The text backbone is instantiated as a Qwen2 causal LM.
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config.with_hf_config(
                self.config, architectures=["Qwen2ForCausalLM"]
            ),
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size,
            self.config.vocab_size,
        )
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> dict[str, torch.Tensor] | None:
        """Extract audio features from mm kwargs; None when no audio present."""
        whisper_input_features = kwargs.pop("whisper_input_features", None)
        if whisper_input_features is None:
            return None
        return {"whisper_input_features": whisper_input_features}

    def _process_audio_input(
        self, audio_input: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """Encode audio features and project them into the LLM hidden size."""
        input_features = audio_input["whisper_input_features"]
        # KimiAudioWhisperEncoder expects list of tensors
        if input_features.dim() == 3:
            input_features = input_features.unbind(dim=0)
        # Run through Whisper encoder
        audio_features = self.audio_tower(input_features)
        # Reshape for 4x downsampling (Whisper outputs at 50Hz, need 12.5Hz)
        B, T, D = audio_features.shape
        if T % 4 != 0:
            pad_len = 4 - (T % 4)
            audio_features = torch.nn.functional.pad(audio_features, (0, 0, 0, pad_len))
            T = audio_features.shape[1]  # Update T after padding
        # Stack groups of 4 frames along the feature dim: [B, T/4, D*4].
        audio_features = audio_features.reshape(B, T // 4, D * 4)
        # Project to LLM dimension
        audio_embeds = self.multi_modal_projector(audio_features)
        return audio_embeds

    def embed_multimodal(self, **kwargs: object) -> list[torch.Tensor] | None:
        """Return one [T, D] audio-embedding tensor per audio item."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []
        audio_embeds = self._process_audio_input(audio_input)
        # audio_embeds shape: [batch_size, seq_len, hidden_dim]
        # Return as list of 2D tensors, one per batch item
        if audio_embeds.dim() == 3:
            # Unbind batch dimension: [B, T, D] -> list of B tensors [T, D]
            return list(audio_embeds.unbind(dim=0))
        else:
            # Single sample: [T, D] -> wrap in list
            return [audio_embeds]

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: tuple[torch.Tensor, ...] | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed input IDs and fuse with audio embeddings.

        Kimi-Audio fusion: inputs_embeds = (text_emb + audio_emb) × √2

        For PP compatibility, we use the is_multimodal mask from vLLM engine
        which is correctly computed per pipeline stage.

        NOTE(review): handle_oov_mm_token is accepted for interface
        compatibility but is not used by this implementation.
        """
        # Get text embeddings
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds
        # is_multimodal must be provided for PP to work correctly
        if is_multimodal is None or not is_multimodal.any():
            return inputs_embeds
        # multimodal_embeddings[0] contains audio embeddings
        audio_embeds = multimodal_embeddings[0]
        # Handle different tensor structures: flatten to [num_tokens, D].
        if isinstance(audio_embeds, (list, tuple)):
            audio_embeds = torch.cat(audio_embeds, dim=0)
        elif audio_embeds.dim() == 3:
            audio_embeds = audio_embeds.reshape(-1, audio_embeds.shape[-1])
        # In PP, audio_embeds count should match is_multimodal.sum()
        # For now, use embeddings sequentially
        # (works for non-PP, PP needs vLLM infra fix)
        num_mm_tokens = is_multimodal.sum().item()
        num_audio_embeds = audio_embeds.shape[0]
        # Use the minimum of available embeddings and positions
        # This ensures we don't access out-of-bounds
        num_to_use = min(num_audio_embeds, num_mm_tokens)
        # Get positions for the tokens we'll actually process
        mm_positions = is_multimodal.nonzero(as_tuple=True)[0]
        actual_mm_mask = torch.zeros_like(is_multimodal)
        actual_mm_mask[mm_positions[:num_to_use]] = True
        # Use corresponding embeddings
        used_audio_embeds = audio_embeds[:num_to_use]
        # Save text embeddings at multimodal positions
        text_at_mm_positions = inputs_embeds[actual_mm_mask].clone()
        # Replace text with audio at multimodal positions
        inputs_embeds[actual_mm_mask] = used_audio_embeds.to(dtype=inputs_embeds.dtype)
        # Apply Kimi-Audio's unique fusion formula: (text + audio) × √2
        inputs_embeds[actual_mm_mask] = (
            inputs_embeds[actual_mm_mask] + text_at_mm_positions
        ) * (2**0.5)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model; embeds are ignored on non-first PP stages."""
        if intermediate_tensors is not None:
            inputs_embeds = None
        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata | None = None,
    ) -> torch.Tensor | None:
        """Compute vocabulary logits from the final hidden states."""
        logits = self.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata
        )
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping MIMO layers (TTS-only) for ASR."""
        # Filter out MIMO/TTS weights since we only do ASR (speech-to-text)
        skipped_patterns = [
            "mimo_layers.",
            "mimo_output.",
            "mimo_norm.",
            "audio_decoder.",
        ]
        # Filter weights
        filtered_weights = [
            (name, param)
            for name, param in weights
            if not any(pattern in name for pattern in skipped_patterns)
        ]
        # Separate main weights (non-Whisper) from Whisper weights
        main_weights = [
            (name, param)
            for name, param in filtered_weights
            if not name.startswith("audio_tower.")
        ]
        # Load main model weights (LLM + projector) with mapper
        loader = AutoWeightsLoader(self)
        loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
        # Load Whisper encoder weights from subfolder
        whisper_path = os.path.join(
            self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors"
        )
        if os.path.exists(whisper_path):
            whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
            loaded.update(whisper_loaded)
        return loaded

    def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]:
        """Load Whisper encoder weights from safetensors file with transformations."""
        if not os.path.exists(whisper_path):
            return set()
        # Step 1: Load raw weights from safetensors file, dropping the
        # "model.encoder." prefix and the (randomly initialized) positions.
        whisper_weights = []
        with safe_open(whisper_path, framework="pt") as f:
            for key in f.keys():  # noqa: SIM118
                if key.startswith("model.encoder.") and "embed_positions" not in key:
                    new_key = key.replace("model.encoder.", "")
                    whisper_weights.append((new_key, f.get_tensor(key)))
        # Step 2: Apply fc → mlp mapping using WeightsMapper
        fc_mapper = WeightsMapper(
            orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
        )
        whisper_mapped = list(fc_mapper.apply(whisper_weights))
        # Step 3: Apply Q/K/V fusion manually
        stacked_params_mapping = [
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
        ]
        params_dict = dict(self.audio_tower.named_parameters())
        whisper_loaded: set[str] = set()
        for name, loaded_weight in whisper_mapped:
            fused = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                fused_name = name.replace(weight_name, param_name)
                if fused_name not in params_dict:
                    continue
                param = params_dict[fused_name]
                param.weight_loader(param, loaded_weight, shard_id)
                whisper_loaded.add(f"audio_tower.{fused_name}")
                fused = True
                break
            if not fused:
                # Non-fused weight: load directly; biases without a matching
                # parameter are silently skipped.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                whisper_loaded.add(f"audio_tower.{name}")
        # Add embed_positions which is initialized randomly
        whisper_loaded.add("audio_tower.embed_positions.weight")
        return whisper_loaded

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Get speech-to-text config with custom processor."""
        # Load feature extractor for config values
        feature_extractor = cached_feature_extractor_from_config(
            model_config,
            subfolder=KIMIA_WHISPER_SUBFOLDER,
        )
        return SpeechToTextConfig(
            max_audio_clip_s=feature_extractor.chunk_length,
            sample_rate=feature_extractor.sampling_rate,
        )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the token prompt for a transcription/translation request.

        NOTE(review): ``language`` and ``to_language`` are currently unused —
        the prompt does not encode the target language; confirm intended.
        """
        tokenizer = cached_get_tokenizer(
            model_config.tokenizer,
            tokenizer_cls=KimiAudioTokenizer,
            tokenizer_mode=model_config.tokenizer_mode,
            revision=model_config.tokenizer_revision,
            trust_remote_code=model_config.trust_remote_code,
        )
        if task_type not in ("transcribe", "translate"):
            raise ValueError(
                f"Unsupported task_type '{task_type}'. "
                "Supported task types are 'transcribe' and 'translate'."
            )
        # Incorporate request_prompt as context/instruction if provided
        user_content = (
            f"{request_prompt}\n{cls.AUDIO_PLACEHOLDER}"
            if request_prompt
            else cls.AUDIO_PLACEHOLDER
        )
        prompt = (
            f"<|im_kimia_user_msg_start|>{user_content}"
            f"<|im_msg_end|><|im_kimia_assistant_msg_start|>"
        )
        prompt_token_ids = tokenizer.encode(prompt)
        return TokensPrompt(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data={"audio": audio},
        )

    @classmethod
    def post_process_output(cls, text: str) -> str:
        """Strip surrounding whitespace from the decoded transcription."""
        if not text:
            return ""
        return text.strip()

View File

@@ -421,6 +421,7 @@ _MULTIMODAL_MODELS = {
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
"MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"), # noqa: E501
"LightOnOCRForConditionalGeneration": (
"lightonocr",
"LightOnOCRForConditionalGeneration",

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, cast
from vllm.config import VllmConfig
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
from vllm.tokenizers.registry import get_tokenizer
from .hf import HfRenderer, HfTokenizer
class KimiAudioRenderer(HfRenderer):
    """Renderer for Kimi-Audio models.

    This renderer uses HfRenderer internally with a custom TikToken tokenizer.
    """

    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        config: VllmConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "HfRenderer":
        """Create an HfRenderer instance for Kimi-Audio models.

        Args:
            config: Full vLLM config; only ``model_config`` is consulted.
            tokenizer_kwargs: Kwargs already normalized by
                ``tokenizer_args_from_config``. NOTE: ``tokenizer_name`` is
                popped from this dict in place (caller-visible mutation).
        """
        model_config = config.model_config
        if model_config.skip_tokenizer_init:
            # Tokenizer-free mode (e.g. token-ID-only prompts).
            tokenizer = None
        else:
            # Extract tokenizer_name from kwargs (already processed by
            # tokenizer_args_from_config for ModelScope/GGUF/etc)
            tokenizer_name = tokenizer_kwargs.pop(
                "tokenizer_name", model_config.tokenizer
            )
            # Remove tokenizer_cls from kwargs to avoid duplicate argument
            tokenizer_kwargs = {
                k: v for k, v in tokenizer_kwargs.items() if k != "tokenizer_cls"
            }
            # Use get_tokenizer directly instead of cached_get_tokenizer
            # (KimiAudioTokenizer doesn't work with get_cached_tokenizer)
            tokenizer = cast(
                HfTokenizer,
                get_tokenizer(
                    tokenizer_name,
                    tokenizer_cls=KimiAudioTokenizer,  # type: ignore[arg-type]
                    **tokenizer_kwargs,
                ),
            )
        return HfRenderer(config, tokenizer)

View File

@@ -19,6 +19,7 @@ _VLLM_RENDERERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
"hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"),
"kimi_audio": ("kimi_audio", "KimiAudioRenderer"),
"mistral": ("mistral", "MistralRenderer"),
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
"terratorch": ("terratorch", "TerratorchRenderer"),
@@ -74,10 +75,18 @@ RENDERER_REGISTRY = RendererRegistry(
def renderer_from_config(config: "VllmConfig", **kwargs):
model_config = config.model_config
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
model_config, **kwargs
)
# Override tokenizer_mode for Kimi-Audio models
if model_config.architecture == "MoonshotKimiaForCausalLM":
tokenizer_mode = "kimi_audio"
# Update model_config so other components (e.g., multimodal registry)
# also use the correct tokenizer mode
model_config.tokenizer_mode = "kimi_audio"
if (
model_config.tokenizer_mode == "auto"
and model_config.model_impl == "terratorch"

View File

@@ -0,0 +1,410 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tokenizer for Kimi-Audio using TikToken."""
import contextlib
import json
from pathlib import Path
from typing import Any, overload
import pybase64
import tiktoken
from huggingface_hub import hf_hub_download
from transformers import AddedToken, BatchEncoding
from transformers.utils import chat_template_utils as hf_chat_utils
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
from vllm.tokenizers.protocol import TokenizerLike
logger = init_logger(__name__)
def _load_tiktoken_encoding(
    vocab_file: Path, special_tokens: dict[str, int]
) -> tuple[Any, dict[str, int]]:
    """Load TikToken encoding from vocab file.

    Each vocab line is a "<base64 token> <rank>" pair; blank or malformed
    lines are skipped. Returns the encoding together with the (unchanged)
    special-token mapping.
    """
    mergeable_ranks: dict[bytes, int] = {}
    with open(vocab_file, encoding="utf-8") as f:
        for raw_line in f:
            fields = raw_line.strip().split()
            if len(fields) != 2:
                continue
            encoded_token, rank_str = fields
            mergeable_ranks[pybase64.b64decode(encoded_token)] = int(rank_str)
    encoding = tiktoken.Encoding(
        name=str(vocab_file),
        pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
        r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
        mergeable_ranks=mergeable_ranks,
        special_tokens=special_tokens,
    )
    return encoding, special_tokens
class KimiAudioTokenizer(TokenizerLike):
"""TikToken tokenizer for Kimi-Audio."""
    @classmethod
    def from_pretrained(
        cls,
        path_or_repo_id: str | Path,
        *args,
        trust_remote_code: bool = False,
        revision: str | None = None,
        download_dir: str | None = None,
        **kwargs,
    ) -> "KimiAudioTokenizer":
        """Load the tokenizer from a local file/dir or a HF Hub repo.

        Resolution order: an explicit file path; ``tiktoken.model`` then
        ``tokenizer.model`` inside a local directory; otherwise the same two
        filenames downloaded from the Hub.

        NOTE(review): ``trust_remote_code`` is accepted for interface
        compatibility but not used here.

        Raises:
            ValueError: if neither vocab file exists in the Hub repo.
            FileNotFoundError: if the resolved local vocab file is missing.
        """
        if args:
            logger.debug_once("Ignoring extra positional args for KimiAudioTokenizer.")
        path = Path(path_or_repo_id)
        if path.is_file():
            # An explicit vocab file path was given.
            vocab_file = path
        elif path.is_dir():
            vocab_file = path / "tiktoken.model"
            if not vocab_file.is_file():
                vocab_file = path / "tokenizer.model"
        else:
            # Download from HuggingFace Hub
            repo_id = str(path_or_repo_id)
            # Try to download tiktoken.model or tokenizer.model
            try:
                vocab_path = hf_hub_download(
                    repo_id=repo_id,
                    filename="tiktoken.model",
                    revision=revision,
                    local_dir=download_dir,
                )
                vocab_file = Path(vocab_path)
            except Exception:
                try:
                    vocab_path = hf_hub_download(
                        repo_id=repo_id,
                        filename="tokenizer.model",
                        revision=revision,
                        local_dir=download_dir,
                    )
                    vocab_file = Path(vocab_path)
                except Exception as exc:
                    raise ValueError(
                        f"Could not find tiktoken.model or tokenizer.model in {repo_id}"
                    ) from exc
            # Also download tokenizer_config.json if available; __init__ reads
            # special tokens from it when it sits next to the vocab file.
            with contextlib.suppress(Exception):
                hf_hub_download(
                    repo_id=repo_id,
                    filename="tokenizer_config.json",
                    revision=revision,
                    local_dir=download_dir,
                )
        if not vocab_file.is_file():
            raise FileNotFoundError(f"tiktoken.model not found at {vocab_file}.")
        return cls(
            vocab_file=vocab_file,
            name_or_path=str(path_or_repo_id),
            truncation_side=kwargs.get("truncation_side", "left"),
        )
def __init__(
self,
*,
vocab_file: Path,
name_or_path: str,
truncation_side: str,
) -> None:
super().__init__()
self.name_or_path = name_or_path
self._truncation_side = truncation_side
self._vocab_file = vocab_file
# Load special tokens from tokenizer_config.json
special_tokens: dict[str, int] = {}
tokenizer_config = vocab_file.parent / "tokenizer_config.json"
if tokenizer_config.is_file():
with open(tokenizer_config, encoding="utf-8") as f:
config = json.load(f)
# Extract special tokens from added_tokens_decoder
added_tokens = config.get("added_tokens_decoder", {})
for token_id_str, token_info in added_tokens.items():
token_id = int(token_id_str)
content = token_info.get("content", "")
if content:
special_tokens[content] = token_id
self._tokenizer, self._special_tokens = _load_tiktoken_encoding(
vocab_file, special_tokens
)
# Build token <-> ID mappings
self._token_to_id: dict[str, int] = {}
self._id_to_token: dict[int, str] = {}
for token_bytes, token_id in self._tokenizer._mergeable_ranks.items():
token_str = token_bytes.decode("utf-8", errors="replace")
self._token_to_id[token_str] = token_id
self._id_to_token[token_id] = token_str
# Initialize added_tokens_decoder before adding special tokens
self._added_tokens_decoder: dict[int, Any] = {}
# Add Kimi-Audio special tokens
self._add_kimiaudio_special_tokens()
# Set default special token IDs (will be updated when special tokens are added)
self._bos_token_id = 151643 # Kimi-Audio BOS
self._eos_token_id = 151644 # Kimi-Audio EOS
self._pad_token_id = self._eos_token_id
self._unk_token_id = self._pad_token_id
self._max_chars_per_token = max(
(len(tok) for tok in self._token_to_id), default=10
)
def _add_kimiaudio_special_tokens(self) -> None:
"""Add Kimi-Audio special tokens to the tokenizer."""
# Tokens should already be in self._special_tokens from tokenizer_config.json
# Just add them to added_tokens_decoder for compatibility
kimiaudio_special_tokens = {
"<|im_media_begin|>": 151661,
"<|im_media_end|>": 151663,
"<|im_kimia_text_blank|>": 151666,
"<|im_msg_end|>": 151645,
"<|im_kimia_user_msg_start|>": 151670,
"<|im_kimia_assistant_msg_start|>": 151671,
}
for token_str, token_id in kimiaudio_special_tokens.items():
# Only add if not already present
if token_id not in self._added_tokens_decoder:
self._added_tokens_decoder[token_id] = AddedToken(
token_str, single_word=True, normalized=False, special=True
)
# Also ensure it's in _token_to_id and _id_to_token
if token_str not in self._token_to_id:
self._token_to_id[token_str] = token_id
if token_id not in self._id_to_token:
self._id_to_token[token_id] = token_str
def num_special_tokens_to_add(self) -> int:
return 0
@property
def all_special_tokens(self) -> list[str]:
return list(self._added_tokens_decoder.values())
@property
def all_special_ids(self) -> list[int]:
return list(self._added_tokens_decoder.keys())
@property
def bos_token_id(self) -> int:
return self._bos_token_id
@property
def eos_token_id(self) -> int:
return self._eos_token_id
@property
def pad_token_id(self) -> int:
return self._pad_token_id
@property
def is_fast(self) -> bool:
return False
@property
def vocab_size(self) -> int:
return self._tokenizer.n_vocab
@property
def max_token_id(self) -> int:
return self._tokenizer.n_vocab - 1
@property
def max_chars_per_token(self) -> int:
return self._max_chars_per_token
@property
def truncation_side(self) -> str:
return self._truncation_side
@property
def added_tokens_decoder(self) -> dict[int, Any]:
return self._added_tokens_decoder
@added_tokens_decoder.setter
def added_tokens_decoder(self, value: dict[int, Any]) -> None:
"""Set added tokens decoder and update special token IDs."""
self._added_tokens_decoder = value
# Update special token IDs if known tokens are added
for token_id, token in value.items():
token_str = str(token) if hasattr(token, "__str__") else token
if "<|im_kimia_user_msg_start|>" in token_str:
self._bos_token_id = token_id
elif "<|im_msg_end|>" in token_str or "<|im_end|>" in token_str:
self._eos_token_id = token_id
def get_vocab(self) -> dict[str, int]:
return dict(self._token_to_id)
def __len__(self) -> int:
"""Return vocab size for compatibility with HF tokenizer interface."""
return self._tokenizer.n_vocab
def get_added_vocab(self) -> dict[str, int]:
return {
str(token): token_id
for token_id, token in self._added_tokens_decoder.items()
}
def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]:
if max_length is None or len(tokens) <= max_length:
return tokens
if self.truncation_side == "left":
return tokens[-max_length:]
return tokens[:max_length]
def encode(
self,
text: str,
truncation: bool | None = None,
max_length: int | None = None,
add_special_tokens: bool = True,
**kwargs,
) -> list[int]:
del add_special_tokens
# Allow Kimi-Audio special tokens to be encoded
tokens = self._tokenizer.encode(
text,
allowed_special={
"<|im_media_begin|>",
"<|im_media_end|>",
"<|im_kimia_text_blank|>",
"<|im_msg_end|>",
"<|im_kimia_user_msg_start|>",
"<|im_kimia_assistant_msg_start|>",
},
)
if truncation:
tokens = self._maybe_truncate(tokens, max_length)
return tokens
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
"""Decode token IDs to text, optionally skipping special tokens."""
if isinstance(ids, int):
ids = [ids]
if skip_special_tokens:
# Skip tokens that are in special_tokens (loaded from config)
special_ids = set(self._special_tokens.values())
ids = [token_id for token_id in ids if token_id not in special_ids]
return self._tokenizer.decode(ids)
@overload
def convert_tokens_to_ids(self, tokens: str) -> int: ...
@overload
def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ...
def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
if isinstance(tokens, str):
return self._token_to_id.get(tokens, self._unk_token_id)
return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]
def convert_ids_to_tokens(
self, ids: list[int], skip_special_tokens: bool = False
) -> list[str]:
tokens = []
for token_id in ids:
if skip_special_tokens and token_id in self._added_tokens_decoder:
continue
tokens.append(self._id_to_token.get(token_id, "<|unk|>"))
return tokens
def convert_tokens_to_string(self, tokens: list[str]) -> str:
token_ids = self.convert_tokens_to_ids(tokens)
return self.decode(token_ids, skip_special_tokens=False)
def __call__(
self,
text: str | list[str],
text_pair: str | None = None,
add_special_tokens: bool = True,
truncation: bool = False,
max_length: int | None = None,
**kwargs,
) -> BatchEncoding:
if text_pair is not None:
raise NotImplementedError(
"text_pair is not supported for KimiAudioTokenizer."
)
if isinstance(text, list):
input_ids_batch: list[list[int]] = [
self.encode(
item,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
)
for item in text
]
attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch]
return BatchEncoding(
{"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
)
input_ids = self.encode(
text,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
)
attention_mask = [1] * len(input_ids)
return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})
def get_chat_template(
self, chat_template: str | None, tools: list[dict[str, Any]] | None = None
) -> str | None:
del tools
return chat_template
def apply_chat_template(
self,
messages: list[ChatCompletionMessageParam] | None = None,
tools: list[dict[str, Any]] | None = None,
chat_template: str | None = None,
tokenize: bool = False,
**kwargs,
) -> str | list[int]:
# Handle both 'messages' (protocol) and 'conversation' (caller) parameter names
conversation = messages if messages is not None else kwargs.get("conversation")
if conversation is None:
raise ValueError("Either 'messages' or 'conversation' must be provided.")
template = self.get_chat_template(chat_template, tools=tools)
if template is None:
raise ValueError(
"No chat template available. Provide `chat_template` explicitly."
)
# Use render_jinja_template instead of apply_chat_template
# Note: render_jinja_template returns ([prompts], [generation_indices])
rendered, _ = hf_chat_utils.render_jinja_template(
conversation,
chat_template=template,
tools=tools,
**kwargs,
)
# Extract the first (and usually only) prompt
prompt = rendered[0] if rendered else ""
if tokenize:
return self.encode(prompt, add_special_tokens=False)
return prompt

View File

@@ -35,6 +35,7 @@ _VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"grok2": ("grok2", "Grok2Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"),
"kimi_audio": ("kimi_audio", "KimiAudioTokenizer"),
"mistral": ("mistral", "MistralTokenizer"),
"qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
}

View File

@@ -0,0 +1,13 @@
{% set messages = conversations[0] if conversations else [] -%}
{% if messages and messages[0]['role'] == 'system' -%}
{% set loop_messages = messages[1:] -%}
{% else -%}
{% set loop_messages = messages -%}
{% endif -%}
{% for message in loop_messages -%}
{% if message['role'] == 'user' -%}
<|im_kimia_user_msg_start|>{{ message['content'] }}<|im_msg_end|><|im_kimia_assistant_msg_start|>
{%- elif message['role'] == 'assistant' -%}
{{ message['content'] }}<|im_kimia_text_eos|>
{%- endif -%}
{% endfor -%}

View File

@@ -10,23 +10,6 @@ reasons:
import importlib
_CLASS_TO_MODULE: dict[str, str] = {
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
}
__all__ = [
"BagelProcessor",
"DeepseekVLV2Processor",
@@ -35,6 +18,7 @@ __all__ = [
"GLM4VProcessor",
"HunYuanVLProcessor",
"HunYuanVLImageProcessor",
"KimiAudioProcessor",
"MistralCommonPixtralProcessor",
"MistralCommonVoxtralProcessor",
"OvisProcessor",
@@ -43,6 +27,23 @@ __all__ = [
"Qwen3ASRProcessor",
]
_CLASS_TO_MODULE: dict[str, str] = {
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
"FunASRProcessor": "vllm.transformers_utils.processors.funasr",
"GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
"HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
"HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
"KimiAudioProcessor": "vllm.transformers_utils.processors.kimi_audio",
"MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
"MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
"OvisProcessor": "vllm.transformers_utils.processors.ovis",
"Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
"QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
"Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
}
def __getattr__(name: str):
if name in _CLASS_TO_MODULE:

View File

@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# mypy: ignore-errors
# coding=utf-8
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor for Kimi-Audio ASR model."""
from collections.abc import Mapping
from typing import Any
import numpy as np
import torch
from transformers import AutoFeatureExtractor, BatchFeature, ProcessorMixin
from transformers.audio_utils import AudioInput
from transformers.tokenization_utils_base import TextInput
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after Whisper feature extraction."""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = (
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
return output_lengths
class KimiAudioProcessor(ProcessorMixin):
r"""
Constructs a Kimi-Audio processor.
[`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer.
See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information.
Args:
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
The audio feature extractor.
tokenizer ([`PreTrainedTokenizer`], *optional*):
The text tokenizer.
"""
# Required for ProcessorMixin
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "AutoFeatureExtractor"
tokenizer_class = "AutoTokenizer"
# Special token IDs
KIMIA_MEDIA_BEGIN: int = 151661
KIMIA_MEDIA_END: int = 151663
KIMIA_TEXT_BLANK: int = 151666
# Audio processing constants
AUDIO_SEQ_LEN: int = 376
def __init__(self, feature_extractor=None, tokenizer=None, **kwargs):
    """Initialize the processor.

    Args:
        feature_extractor: Audio feature extractor (Whisper-style), stored
            by ProcessorMixin as ``self.feature_extractor``.
        tokenizer: Text tokenizer, stored as ``self.tokenizer``.
        **kwargs: Forwarded to ``ProcessorMixin.__init__``.
    """
    # Pass feature_extractor and tokenizer to parent ProcessorMixin
    super().__init__(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer,
        **kwargs,
    )
def check_argument_for_proper_class(self, attribute_name: str, argument: Any):
    """Override to skip class validation for custom tokenizer."""
    # KimiAudioTokenizer is duck-type compatible with HF tokenizers but does
    # not inherit from PreTrainedTokenizerBase, so bypass the check for any
    # non-None tokenizer and validate everything else normally.
    if attribute_name != "tokenizer" or argument is None:
        super().check_argument_for_proper_class(attribute_name, argument)
def __call__(
    self,
    text: TextInput = None,
    audio: AudioInput = None,
    return_tensors: str = "pt",
    **kwargs,
) -> BatchFeature:
    """
    Main method to prepare for the model one or several sequences(s) and audio(s).

    Args:
        text (`str`, `List[str]`):
            The sequence or batch of sequences to be encoded. May also be a
            flat list of token IDs coming from the vLLM processor.
        audio (`np.ndarray`, `List[np.ndarray]`):
            The audio or batch of audio to be prepared. Each audio can be a NumPy array.
        return_tensors (`str`):
            The type of tensors to return ("pt", "np", etc.)

    Returns:
        A `BatchFeature` holding `input_ids`/`attention_mask` plus, when audio
        is given, `whisper_input_features` and `feature_attention_mask`.

    Raises:
        ValueError: If `text` is None.
    """
    if text is None:
        raise ValueError("You need to specify either a `text` input to process.")
    # Process audio if provided
    if audio is not None:
        # Ensure audio is a list
        if isinstance(audio, np.ndarray):
            audio = [audio]
        # Pad audio to hop length (required by WhisperFeatureExtractor)
        hop_length = self.feature_extractor.hop_length
        padded_audio = []
        for aud in audio:
            length = aud.shape[-1]
            if length % hop_length != 0:
                # Zero-pad the tail so the length is a hop-length multiple.
                pad_length = hop_length - (length % hop_length)
                aud = np.pad(
                    aud, (0, pad_length), mode="constant", constant_values=0
                )
            padded_audio.append(aud)
        # Use feature_extractor directly like Qwen3ASR does
        # NOTE(review): sampling rate is hard-coded to 16 kHz here — assumes
        # callers resample beforehand; confirm against the model processor.
        audio_inputs = self.feature_extractor(
            padded_audio,
            sampling_rate=16000,
            padding=True,
            return_attention_mask=True,
            return_tensors=return_tensors,
        )
        # Rename to match Kimi-Audio expectations
        if "input_features" in audio_inputs:
            audio_inputs["whisper_input_features"] = audio_inputs.pop(
                "input_features"
            )
        if "attention_mask" in audio_inputs:
            audio_inputs["feature_attention_mask"] = audio_inputs.pop(
                "attention_mask"
            )
    else:
        audio_inputs = {}
    # Handle text input - can be string or token IDs from vLLM processor
    if isinstance(text, list) and len(text) > 0 and isinstance(text[0], int):
        # Text is already token IDs (from vLLM processor) - just wrap
        text_inputs = {"input_ids": torch.tensor([text], dtype=torch.long)}
    else:
        # Text is string - tokenize
        if not isinstance(text, list):
            text = [text]
        text_inputs = self.tokenizer(
            text, return_tensors=return_tensors, padding=True
        )
    # Merge text first, then audio, so audio keys win on (unexpected) clashes.
    return BatchFeature(
        data={**text_inputs, **audio_inputs},
        tensor_type=return_tensors,
    )