diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index edec87e6f..7e685181f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -713,8 +713,9 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `KananaVForConditionalGeneration` | Kanana-V | T + I+ | `kakaocorp/kanana-1.5-v-3b-instruct`, etc. | | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | | `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | +| `KimiAudioForConditionalGeneration` | Kimi-Audio | T + A+ | `moonshotai/Kimi-Audio-7B-Instruct` | | ✅︎ | | `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I+ | `moonshotai/Kimi-K2.5` | | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | | `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I+ | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ | | `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I+ | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 4bf4b4e1d..f7292c468 100755 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -201,6 +201,34 @@ def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: ) +# Kimi-Audio-7B-Instruct +def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData: + """Kimi-Audio-7B-Instruct for audio transcription and understanding.""" + model_name = "moonshotai/Kimi-Audio-7B-Instruct" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=2, + limit_mm_per_prompt={"audio": audio_count}, + ) + + # Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features + audio_placeholder = "<|im_kimia_text_blank|>" * audio_count + # Default prompt for transcription + if not question: + question = "Please transcribe the audio" + prompt = f"{audio_placeholder}{question}" + + # Stop at EOS token (151644) to prevent repetition + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + stop_token_ids=[151644], + ) + + # MiDashengLM def run_midashenglm(question: str, audio_count: int): model_name = "mispeech/midashenglm-7b" @@ -485,6 +513,7 @@ model_example_map = { "glmasr": run_glmasr, "funaudiochat": run_funaudiochat, "granite_speech": run_granite_speech, + "kimi_audio": run_kimi_audio, "midashenglm": run_midashenglm, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index b6470baaa..34da19721 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -198,13 +198,17 @@ def get_text_token_prompts( mm_counts, mm_options={}, ) - assert isinstance(inputs.prompt, str) - - text_prompt = inputs.prompt - token_prompt = tokenizer.encode( - text_prompt, - add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True), - ) + # Some models (e.g., Kimi-Audio) return token IDs directly instead of str + if isinstance(inputs.prompt, list): + text_prompt = None + token_prompt = inputs.prompt + else: + assert isinstance(inputs.prompt, str) + text_prompt = inputs.prompt + token_prompt = tokenizer.encode( + text_prompt, + add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True), + ) return text_prompt, token_prompt diff --git a/tests/models/registry.py b/tests/models/registry.py index 3927b3ac0..17931079c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -857,6 +857,15 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Kwai-Keye/Keye-VL-1_5-8B", trust_remote_code=True, ), + "MoonshotKimiaForCausalLM": _HfExamplesInfo( + "moonshotai/Kimi-Audio-7B-Instruct", + tokenizer_mode="kimi_audio", + trust_remote_code=True, + ), + "KimiK25ForConditionalGeneration": _HfExamplesInfo( + "moonshotai/Kimi-K2.5", + trust_remote_code=True, + ), "KimiVLForConditionalGeneration": _HfExamplesInfo( "moonshotai/Kimi-VL-A3B-Instruct", extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, @@ -870,10 +879,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { ) }, ), - "KimiK25ForConditionalGeneration": _HfExamplesInfo( - "moonshotai/Kimi-K2.5", - trust_remote_code=True, - ), "LightOnOCRForConditionalGeneration": _HfExamplesInfo( "lightonai/LightOnOCR-1B-1025" ), diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 3b0747c8a..375592ba5 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -103,6 +103,12 @@ def can_initialize( "pickle error when loading `transformers.models.auto.CONFIG_MAPPING`" ) + if model_arch == "MoonshotKimiaForCausalLM": + pytest.skip( + "Kimi-Audio requires SpeechToTextConfig " + "which is not configured in test environment" + ) + if model_arch in ["DeepseekV32ForCausalLM", "GlmMoeDsaForCausalLM"]: from vllm.platforms import current_platform diff --git a/vllm/model_executor/models/kimi_audio.py b/vllm/model_executor/models/kimi_audio.py new file mode 100644 index 000000000..cb8ac2efb --- /dev/null +++ b/vllm/model_executor/models/kimi_audio.py @@ -0,0 +1,725 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Inference-only Kimi-Audio model compatible with HuggingFace weights.""" + +import os +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, ClassVar, Literal + +import numpy as np +import torch +import torch.nn as nn +from safetensors import safe_open +from transformers import BatchFeature +from transformers import WhisperConfig as HFWhisperConfig + +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.inputs.data import PromptType, TokensPrompt +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, +) +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + SupportsPP, + SupportsTranscription, +) +from vllm.model_executor.models.utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) +from vllm.model_executor.models.whisper import WhisperEncoder +from vllm.model_executor.models.whisper_utils import ISO639_1_SUPPORTED_LANGS +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.multimodal.parse import ( + AudioItem, + DictEmbeddingItems, + ModalityData, + ModalityDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseProcessingInfo, + PromptReplacement, +) +from vllm.multimodal.processing.processor import BaseMultiModalProcessor +from vllm.sequence import IntermediateTensors +from vllm.tokenizers import cached_get_tokenizer +from vllm.tokenizers.kimi_audio import KimiAudioTokenizer +from vllm.transformers_utils.processor import cached_feature_extractor_from_config +from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor +from vllm.v1.sample.metadata import SamplingMetadata + +# Kimi-Audio constants +KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3" + + +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor: + """Compute output lengths after Whisper feature extraction. + + Whisper processes audio through multiple conv layers with stride=2, + producing 13 output features per 100 input samples. + """ + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + ) + return output_lengths + + +class KimiAudioWhisperEncoder(WhisperEncoder): + """WhisperEncoder for Kimi-Audio with packed_modules_mapping.""" + + # packed_modules_mapping for Q/K/V fusion during weight loading + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "kv_proj": ["k_proj", "v_proj"], + } + + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False + ): + # Load Whisper config from subfolder (authoritative source) + # Kimi-Audio stores Whisper config in whisper-large-v3/config.json + model_path = vllm_config.model_config.model + whisper_config_path = os.path.join(model_path, KIMIA_WHISPER_SUBFOLDER) + + # Load WhisperConfig from the subfolder + whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path) + + # Temporarily replace hf_config for WhisperEncoder.__init__() + original_config = vllm_config.model_config.hf_config + vllm_config.model_config.hf_config = whisper_config + + super().__init__( + vllm_config=vllm_config, prefix=prefix, init_in_fp32=init_in_fp32 + ) + + # Restore original config + vllm_config.model_config.hf_config = original_config + + +# ----------------------------------------------------------------------------- +# Processing Info, Dummy Inputs, and MultiModal Processor +# (Following Qwen3ASR pattern - same file as model) +# ----------------------------------------------------------------------------- + + +class KimiAudioProcessingInfo(BaseProcessingInfo): + """Processing info for vLLM registry.""" + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor: + """Get KimiAudioProcessor with feature extractor and tokenizer.""" + # Use vLLM's cached loader for feature extractor + feature_extractor = cached_feature_extractor_from_config( + self.ctx.model_config, + subfolder=KIMIA_WHISPER_SUBFOLDER, + ) + + # Use vLLM's standard tokenizer loading (respects tokenizer_mode) + tokenizer = self.get_tokenizer() + + # Construct processor directly + return KimiAudioProcessor( + feature_extractor=feature_extractor, + tokenizer=tokenizer, + ) + + def get_feature_extractor(self, **kwargs: object): + """Get feature extractor using vLLM's cached loader.""" + return cached_feature_extractor_from_config( + self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER + ) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": 1} + + def get_data_parser(self) -> "KimiAudioMultiModalDataParser": + """Get data parser for audio inputs.""" + return KimiAudioMultiModalDataParser( + expected_hidden_size=self._get_expected_hidden_size(), + ) + + +class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]): + """Dummy inputs builder for vLLM registry.""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]: + """Return dummy text as token IDs directly.""" + num_audios = mm_counts.get("audio", 0) + if num_audios == 0: + return [198] # "Transcribe" tokenized + # Return as token IDs directly to avoid tokenizer issues + return [ + KimiAudioProcessor.KIMIA_MEDIA_BEGIN, + KimiAudioProcessor.KIMIA_TEXT_BLANK, + KimiAudioProcessor.KIMIA_MEDIA_END, + ] * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, Any] | None = None, + ) -> dict[str, Any]: + num_audios = mm_counts.get("audio", 0) + if num_audios == 0: + return {} + + feature_extractor = self.info.get_feature_extractor() + target_audio_length = ( + min(feature_extractor.chunk_length, 30) * feature_extractor.sampling_rate + ) + + return { + "audio": self._get_dummy_audios( + length=target_audio_length, num_audios=num_audios + ), + } + + +# Field config for Kimi-Audio multimodal data +_KIMIAUDIO_FIELD_CONFIG = { + "whisper_input_features": MultiModalFieldConfig.batched("audio"), + "feature_attention_mask": MultiModalFieldConfig.batched("audio"), +} + + +class KimiAudioMultiModalDataParser(MultiModalDataParser): + """Custom data parser for Kimi-Audio multimodal data.""" + + def __init__(self, **kwargs): + # Whisper expects 16kHz audio + super().__init__(target_sr=16000, **kwargs) + + def _parse_audio_data( + self, + data: dict[str, torch.Tensor] | ModalityData[AudioItem], + ) -> ModalityDataItems[Any, Any] | None: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"whisper_input_features", "feature_attention_mask"}, + fields_factory=lambda hf_inputs: _KIMIAUDIO_FIELD_CONFIG, + ) + + return super()._parse_audio_data(data) + + +class KimiAudioMultiModalProcessor(BaseMultiModalProcessor[KimiAudioProcessingInfo]): + """vLLM multi-modal processor wrapper for Kimi-Audio.""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + """Call the HuggingFace processor.""" + # Convert mm_data format: {'audios': [...]} -> {'audio': ...} + mm_data = dict(mm_data) + audios = mm_data.pop("audios", []) + + # Convert audio format: [(array, sr), ...] -> [array, ...] + # KimiAudioProcessor expects raw numpy arrays + if audios: + audio_arrays = [] + for aud in audios: + if isinstance(aud, (tuple, list)) and len(aud) == 2: + # Format: (audio_array, sampling_rate) + audio_arrays.append(aud[0]) + elif isinstance(aud, np.ndarray): + audio_arrays.append(aud) + else: + audio_arrays.append(aud) + mm_data["audio"] = audio_arrays + + # Use the context's call_hf_processor for proper handling + return self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=prompt, **mm_data), + dict(**mm_kwargs, **tok_kwargs), + ) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, Any]: + """Get multi-modal field configuration.""" + return _KIMIAUDIO_FIELD_CONFIG + + def _get_prompt_updates( + self, + mm_items, + hf_processor_mm_kwargs, + out_mm_kwargs, + ) -> Sequence[PromptReplacement]: + """Get prompt updates for audio tokens.""" + # Get audio feature lengths from processed output + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") + + if feature_attention_mask is not None: + audio_output_lens = _get_feat_extract_output_lengths( + feature_attention_mask.sum(-1) + ) + audio_output_lengths = audio_output_lens.tolist() + else: + audio_output_lengths = [] + + def get_replacement_kimiaudio(item_idx: int): + num_features = ( + audio_output_lengths[item_idx] + if item_idx < len(audio_output_lengths) + else 376 + ) + if num_features == 0: + num_features = 376 # Default Kimi-Audio sequence length + # Return the placeholder token ID repeated num_features times + return [KimiAudioProcessor.KIMIA_TEXT_BLANK] * num_features + + # Use the token ID as target (as a list) + return [ + PromptReplacement( + modality="audio", + target=[KimiAudioProcessor.KIMIA_TEXT_BLANK], + replacement=get_replacement_kimiaudio, + ), + ] + + +# ----------------------------------------------------------------------------- +# Model Definition +# ----------------------------------------------------------------------------- + + +class KimiAudioMultiModalProjector(nn.Module): + """Projects Whisper features to LLM embedding space. + + Kimi-Audio VQ-Adaptor architecture: + Custom Whisper (5120) → Linear[5120→3584] → Linear[3584→3584] → LayerNorm + """ + + def __init__( + self, + whisper_dim: int = 5120, # Kimi-Audio custom Whisper encoder dim + llm_dim: int = 3584, + prefix: str = "", + ): + super().__init__() + self.whisper_dim = whisper_dim + self.llm_dim = llm_dim + + # VQ-Adaptor layers (exact checkpoint structure) + # layers.0: Linear[5120 → 3584] + self.vq_adaptor_layers_0 = nn.Linear(whisper_dim, llm_dim) + # layers.3: Linear[3584 → 3584] + self.vq_adaptor_layers_3 = nn.Linear(llm_dim, llm_dim) + # layers.4: LayerNorm[3584] + self.vq_adaptor_layers_4 = nn.LayerNorm(llm_dim) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + # Project: [B, T, 5120] → [B, T, 3584] + hidden = self.vq_adaptor_layers_0(audio_features) + hidden = torch.nn.functional.gelu(hidden) + hidden = self.vq_adaptor_layers_3(hidden) + hidden = self.vq_adaptor_layers_4(hidden) + return hidden + + +@MULTIMODAL_REGISTRY.register_processor( + KimiAudioMultiModalProcessor, + info=KimiAudioProcessingInfo, + dummy_inputs=KimiAudioDummyInputsBuilder, +) +class KimiAudioForConditionalGeneration( + nn.Module, + SupportsMultiModal, + SupportsPP, + SupportsTranscription, +): + """Kimi-Audio model for ASR transcription.""" + + # Kimi-Audio supports a subset of Whisper's supported languages + supported_languages: ClassVar[Mapping[str, str]] = { + k: ISO639_1_SUPPORTED_LANGS[k] + for k in ["zh", "en", "ja", "ko", "de", "fr", "es", "it", "pt", "ru", "ar"] + } + supports_transcription: ClassVar[Literal[True]] = True + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # Audio projector (VQ-Adaptor) + "model.vq_adaptor.layers.0.": "multi_modal_projector.vq_adaptor_layers_0.", + "model.vq_adaptor.layers.3.": "multi_modal_projector.vq_adaptor_layers_3.", + "model.vq_adaptor.layers.4.": "multi_modal_projector.vq_adaptor_layers_4.", + # Language model + "model.layers.": "language_model.model.layers.", + # Embeddings and output + "model.embed_tokens.": "language_model.model.embed_tokens.", + "model.norm.": "language_model.model.norm.", + "lm_head.": "language_model.lm_head.", + } + ) + + # Audio placeholder token sequence + AUDIO_PLACEHOLDER = "<|im_media_begin|><|im_kimia_text_blank|><|im_media_end|>" + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + return cls.AUDIO_PLACEHOLDER if modality.startswith("audio") else None + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.quant_config = vllm_config.quant_config + self.multimodal_config = vllm_config.model_config.multimodal_config + self.model_path = vllm_config.model_config.model + + self.audio_tower = KimiAudioWhisperEncoder( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "audio_tower"), + ) + + self.multi_modal_projector = KimiAudioMultiModalProjector( + whisper_dim=getattr(self.config, "kimia_adaptor_input_dim", 5120), + llm_dim=self.config.hidden_size, + prefix=maybe_prefix(prefix, "multi_modal_projector"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config( + self.config, architectures=["Qwen2ForCausalLM"] + ), + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.logits_processor = LogitsProcessor( + self.config.vocab_size, + self.config.vocab_size, + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> dict[str, torch.Tensor] | None: + whisper_input_features = kwargs.pop("whisper_input_features", None) + if whisper_input_features is None: + return None + + return {"whisper_input_features": whisper_input_features} + + def _process_audio_input( + self, audio_input: dict[str, torch.Tensor] + ) -> torch.Tensor: + input_features = audio_input["whisper_input_features"] + + # KimiAudioWhisperEncoder expects list of tensors + if input_features.dim() == 3: + input_features = input_features.unbind(dim=0) + + # Run through Whisper encoder + audio_features = self.audio_tower(input_features) + + # Reshape for 4x downsampling (Whisper outputs at 50Hz, need 12.5Hz) + B, T, D = audio_features.shape + if T % 4 != 0: + pad_len = 4 - (T % 4) + audio_features = torch.nn.functional.pad(audio_features, (0, 0, 0, pad_len)) + T = audio_features.shape[1] # Update T after padding + + audio_features = audio_features.reshape(B, T // 4, D * 4) + + # Project to LLM dimension + audio_embeds = self.multi_modal_projector(audio_features) + return audio_embeds + + def embed_multimodal(self, **kwargs: object) -> list[torch.Tensor] | None: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is None: + return [] + + audio_embeds = self._process_audio_input(audio_input) + + # audio_embeds shape: [batch_size, seq_len, hidden_dim] + # Return as list of 2D tensors, one per batch item + if audio_embeds.dim() == 3: + # Unbind batch dimension: [B, T, D] -> list of B tensors [T, D] + return list(audio_embeds.unbind(dim=0)) + else: + # Single sample: [T, D] -> wrap in list + return [audio_embeds] + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: tuple[torch.Tensor, ...] | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + """Embed input IDs and fuse with audio embeddings. + + Kimi-Audio fusion: inputs_embeds = (text_emb + audio_emb) × √2 + + For PP compatibility, we use the is_multimodal mask from vLLM engine + which is correctly computed per pipeline stage. + """ + # Get text embeddings + inputs_embeds = self.language_model.model.embed_tokens(input_ids) + + if multimodal_embeddings is None or len(multimodal_embeddings) == 0: + return inputs_embeds + + # is_multimodal must be provided for PP to work correctly + if is_multimodal is None or not is_multimodal.any(): + return inputs_embeds + + # multimodal_embeddings[0] contains audio embeddings + audio_embeds = multimodal_embeddings[0] + + # Handle different tensor structures + if isinstance(audio_embeds, (list, tuple)): + audio_embeds = torch.cat(audio_embeds, dim=0) + elif audio_embeds.dim() == 3: + audio_embeds = audio_embeds.reshape(-1, audio_embeds.shape[-1]) + + # In PP, audio_embeds count should match is_multimodal.sum() + # For now, use embeddings sequentially + # (works for non-PP, PP needs vLLM infra fix) + num_mm_tokens = is_multimodal.sum().item() + num_audio_embeds = audio_embeds.shape[0] + + # Use the minimum of available embeddings and positions + # This ensures we don't access out-of-bounds + num_to_use = min(num_audio_embeds, num_mm_tokens) + + # Get positions for the tokens we'll actually process + mm_positions = is_multimodal.nonzero(as_tuple=True)[0] + actual_mm_mask = torch.zeros_like(is_multimodal) + actual_mm_mask[mm_positions[:num_to_use]] = True + + # Use corresponding embeddings + used_audio_embeds = audio_embeds[:num_to_use] + + # Save text embeddings at multimodal positions + text_at_mm_positions = inputs_embeds[actual_mm_mask].clone() + + # Replace text with audio at multimodal positions + inputs_embeds[actual_mm_mask] = used_audio_embeds.to(dtype=inputs_embeds.dtype) + + # Apply Kimi-Audio's unique fusion formula: (text + audio) × √2 + inputs_embeds[actual_mm_mask] = ( + inputs_embeds[actual_mm_mask] + text_at_mm_positions + ) * (2**0.5) + + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata | None = None, + ) -> torch.Tensor | None: + logits = self.logits_processor( + self.language_model.lm_head, hidden_states, sampling_metadata + ) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights, skipping MIMO layers (TTS-only) for ASR.""" + # Filter out MIMO/TTS weights since we only do ASR (speech-to-text) + skipped_patterns = [ + "mimo_layers.", + "mimo_output.", + "mimo_norm.", + "audio_decoder.", + ] + + # Filter weights + filtered_weights = [ + (name, param) + for name, param in weights + if not any(pattern in name for pattern in skipped_patterns) + ] + + # Separate main weights (non-Whisper) from Whisper weights + main_weights = [ + (name, param) + for name, param in filtered_weights + if not name.startswith("audio_tower.") + ] + + # Load main model weights (LLM + projector) with mapper + loader = AutoWeightsLoader(self) + loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper) + + # Load Whisper encoder weights from subfolder + whisper_path = os.path.join( + self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors" + ) + if os.path.exists(whisper_path): + whisper_loaded = self._load_whisper_weights_from_file(whisper_path) + loaded.update(whisper_loaded) + + return loaded + + def _load_whisper_weights_from_file(self, whisper_path: str) -> set[str]: + """Load Whisper encoder weights from safetensors file with transformations.""" + if not os.path.exists(whisper_path): + return set() + + # Step 1: Load raw weights from safetensors file + whisper_weights = [] + with safe_open(whisper_path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + if key.startswith("model.encoder.") and "embed_positions" not in key: + new_key = key.replace("model.encoder.", "") + whisper_weights.append((new_key, f.get_tensor(key))) + + # Step 2: Apply fc → mlp mapping using WeightsMapper + fc_mapper = WeightsMapper( + orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."} + ) + whisper_mapped = list(fc_mapper.apply(whisper_weights)) + + # Step 3: Apply Q/K/V fusion manually + stacked_params_mapping = [ + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + ] + + params_dict = dict(self.audio_tower.named_parameters()) + whisper_loaded: set[str] = set() + + for name, loaded_weight in whisper_mapped: + fused = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + fused_name = name.replace(weight_name, param_name) + if fused_name not in params_dict: + continue + + param = params_dict[fused_name] + param.weight_loader(param, loaded_weight, shard_id) + whisper_loaded.add(f"audio_tower.{fused_name}") + fused = True + break + + if not fused: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + whisper_loaded.add(f"audio_tower.{name}") + + # Add embed_positions which is initialized randomly + whisper_loaded.add("audio_tower.embed_positions.weight") + + return whisper_loaded + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + """Get speech-to-text config with custom processor.""" + # Load feature extractor for config values + feature_extractor = cached_feature_extractor_from_config( + model_config, + subfolder=KIMIA_WHISPER_SUBFOLDER, + ) + + return SpeechToTextConfig( + max_audio_clip_s=feature_extractor.chunk_length, + sample_rate=feature_extractor.sampling_rate, + ) + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + tokenizer_cls=KimiAudioTokenizer, + tokenizer_mode=model_config.tokenizer_mode, + revision=model_config.tokenizer_revision, + trust_remote_code=model_config.trust_remote_code, + ) + + if task_type not in ("transcribe", "translate"): + raise ValueError( + f"Unsupported task_type '{task_type}'. " + "Supported task types are 'transcribe' and 'translate'." + ) + + # Incorporate request_prompt as context/instruction if provided + user_content = ( + f"{request_prompt}\n{cls.AUDIO_PLACEHOLDER}" + if request_prompt + else cls.AUDIO_PLACEHOLDER + ) + + prompt = ( + f"<|im_kimia_user_msg_start|>{user_content}" + f"<|im_msg_end|><|im_kimia_assistant_msg_start|>" + ) + + prompt_token_ids = tokenizer.encode(prompt) + + return TokensPrompt( + prompt_token_ids=prompt_token_ids, + multi_modal_data={"audio": audio}, + ) + + @classmethod + def post_process_output(cls, text: str) -> str: + if not text: + return "" + return text.strip() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 34dda9b38..00bfa8c65 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -421,6 +421,7 @@ _MULTIMODAL_MODELS = { "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501 + "MoonshotKimiaForCausalLM": ("kimi_audio", "KimiAudioForConditionalGeneration"), # noqa: E501 "LightOnOCRForConditionalGeneration": ( "lightonocr", "LightOnOCRForConditionalGeneration", diff --git a/vllm/renderers/kimi_audio.py b/vllm/renderers/kimi_audio.py new file mode 100644 index 000000000..4df2cb78c --- /dev/null +++ b/vllm/renderers/kimi_audio.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, cast + +from vllm.config import VllmConfig +from vllm.tokenizers.kimi_audio import KimiAudioTokenizer +from vllm.tokenizers.registry import get_tokenizer + +from .hf import HfRenderer, HfTokenizer + + +class KimiAudioRenderer(HfRenderer): + """Renderer for Kimi-Audio models. + + This renderer uses HfRenderer internally with a custom TikToken tokenizer. + """ + + @classmethod + def from_config( # type: ignore[override] + cls, + config: VllmConfig, + tokenizer_kwargs: dict[str, Any], + ) -> "HfRenderer": + """Create an HfRenderer instance for Kimi-Audio models.""" + model_config = config.model_config + if model_config.skip_tokenizer_init: + tokenizer = None + else: + # Extract tokenizer_name from kwargs (already processed by + # tokenizer_args_from_config for ModelScope/GGUF/etc) + tokenizer_name = tokenizer_kwargs.pop( + "tokenizer_name", model_config.tokenizer + ) + # Remove tokenizer_cls from kwargs to avoid duplicate argument + tokenizer_kwargs = { + k: v for k, v in tokenizer_kwargs.items() if k != "tokenizer_cls" + } + # Use get_tokenizer directly instead of cached_get_tokenizer + # (KimiAudioTokenizer doesn't work with get_cached_tokenizer) + tokenizer = cast( + HfTokenizer, + get_tokenizer( + tokenizer_name, + tokenizer_cls=KimiAudioTokenizer, # type: ignore[arg-type] + **tokenizer_kwargs, + ), + ) + + return HfRenderer(config, tokenizer) diff --git a/vllm/renderers/registry.py b/vllm/renderers/registry.py index de95505ec..90f7fd2d3 100644 --- a/vllm/renderers/registry.py +++ b/vllm/renderers/registry.py @@ -19,6 +19,7 @@ _VLLM_RENDERERS = { "deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"), "hf": ("hf", "HfRenderer"), "grok2": ("grok2", "Grok2Renderer"), + "kimi_audio": ("kimi_audio", "KimiAudioRenderer"), "mistral": ("mistral", "MistralRenderer"), "qwen_vl": ("qwen_vl", "QwenVLRenderer"), "terratorch": ("terratorch", "TerratorchRenderer"), @@ -74,10 +75,18 @@ RENDERER_REGISTRY = RendererRegistry( def renderer_from_config(config: "VllmConfig", **kwargs): model_config = config.model_config + tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config( model_config, **kwargs ) + # Override tokenizer_mode for Kimi-Audio models + if model_config.architecture == "MoonshotKimiaForCausalLM": + tokenizer_mode = "kimi_audio" + # Update model_config so other components (e.g., multimodal registry) + # also use the correct tokenizer mode + model_config.tokenizer_mode = "kimi_audio" + if ( model_config.tokenizer_mode == "auto" and model_config.model_impl == "terratorch" diff --git a/vllm/tokenizers/kimi_audio.py b/vllm/tokenizers/kimi_audio.py new file mode 100644 index 000000000..ef3f9efb8 --- /dev/null +++ b/vllm/tokenizers/kimi_audio.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tokenizer for Kimi-Audio using TikToken.""" + +import contextlib +import json +from pathlib import Path +from typing import Any, overload + +import pybase64 +import tiktoken +from huggingface_hub import hf_hub_download +from transformers import AddedToken, BatchEncoding +from transformers.utils import chat_template_utils as hf_chat_utils + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.logger import init_logger +from vllm.tokenizers.protocol import TokenizerLike + +logger = init_logger(__name__) + + +def _load_tiktoken_encoding( + vocab_file: Path, special_tokens: dict[str, int] +) -> tuple[Any, dict[str, int]]: + """Load TikToken encoding from vocab file.""" + mergeable_ranks: dict[bytes, int] = {} + with open(vocab_file, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split() + if len(parts) == 2: + token_b64 = parts[0] + rank = int(parts[1]) + token_bytes = pybase64.b64decode(token_b64) + mergeable_ranks[token_bytes] = rank + + tokenizer = tiktoken.Encoding( + name=str(vocab_file), + pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|""" + r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""", + mergeable_ranks=mergeable_ranks, + special_tokens=special_tokens, + ) + + return tokenizer, special_tokens + + +class KimiAudioTokenizer(TokenizerLike): + """TikToken tokenizer for Kimi-Audio.""" + + @classmethod + def from_pretrained( + cls, + path_or_repo_id: str | Path, + *args, + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, + ) -> "KimiAudioTokenizer": + if args: + logger.debug_once("Ignoring extra positional args for KimiAudioTokenizer.") + + path = Path(path_or_repo_id) + if path.is_file(): + vocab_file = path + elif path.is_dir(): + vocab_file = path / "tiktoken.model" + if not vocab_file.is_file(): + vocab_file = path / "tokenizer.model" + else: + # Download from HuggingFace Hub + repo_id = str(path_or_repo_id) + + # Try to download tiktoken.model or tokenizer.model + try: + vocab_path = hf_hub_download( + repo_id=repo_id, + filename="tiktoken.model", + revision=revision, + local_dir=download_dir, + ) + vocab_file = Path(vocab_path) + except Exception: + try: + vocab_path = hf_hub_download( + repo_id=repo_id, + filename="tokenizer.model", + revision=revision, + local_dir=download_dir, + ) + vocab_file = Path(vocab_path) + except Exception as exc: + raise ValueError( + f"Could not find tiktoken.model or tokenizer.model in {repo_id}" + ) from exc + + # Also download tokenizer_config.json if available + with contextlib.suppress(Exception): + hf_hub_download( + repo_id=repo_id, + filename="tokenizer_config.json", + revision=revision, + local_dir=download_dir, + ) + + if not vocab_file.is_file(): + raise FileNotFoundError(f"tiktoken.model not found at {vocab_file}.") + + return cls( + vocab_file=vocab_file, + name_or_path=str(path_or_repo_id), + truncation_side=kwargs.get("truncation_side", "left"), + ) + + def __init__( + self, + *, + vocab_file: Path, + name_or_path: str, + truncation_side: str, + ) -> None: + super().__init__() + self.name_or_path = name_or_path + self._truncation_side = truncation_side + self._vocab_file = vocab_file + + # Load special tokens from tokenizer_config.json + special_tokens: dict[str, int] = {} + tokenizer_config = vocab_file.parent / "tokenizer_config.json" + if tokenizer_config.is_file(): + with open(tokenizer_config, encoding="utf-8") as f: + config = json.load(f) + # Extract special tokens from added_tokens_decoder + added_tokens = config.get("added_tokens_decoder", {}) + for token_id_str, token_info in added_tokens.items(): + token_id = int(token_id_str) + content = token_info.get("content", "") + if content: + special_tokens[content] = token_id + + self._tokenizer, self._special_tokens = _load_tiktoken_encoding( + vocab_file, special_tokens + ) + + # Build token <-> ID mappings + self._token_to_id: dict[str, int] = {} + self._id_to_token: dict[int, str] = {} + for token_bytes, token_id in self._tokenizer._mergeable_ranks.items(): + token_str = token_bytes.decode("utf-8", errors="replace") + self._token_to_id[token_str] = token_id + self._id_to_token[token_id] = token_str + + # Initialize added_tokens_decoder before adding special tokens + self._added_tokens_decoder: dict[int, Any] = {} + + # Add Kimi-Audio special tokens + self._add_kimiaudio_special_tokens() + + # Set default special token IDs (will be updated when special tokens are added) + self._bos_token_id = 151643 # Kimi-Audio BOS + self._eos_token_id = 151644 # Kimi-Audio EOS + self._pad_token_id = self._eos_token_id + self._unk_token_id = self._pad_token_id + + self._max_chars_per_token = max( + (len(tok) for tok in self._token_to_id), default=10 + ) + + def _add_kimiaudio_special_tokens(self) -> None: + """Add Kimi-Audio special tokens to the tokenizer.""" + # Tokens should already be in self._special_tokens from tokenizer_config.json + # Just add them to added_tokens_decoder for compatibility + kimiaudio_special_tokens = { + "<|im_media_begin|>": 151661, + "<|im_media_end|>": 151663, + "<|im_kimia_text_blank|>": 151666, + "<|im_msg_end|>": 151645, + "<|im_kimia_user_msg_start|>": 151670, + "<|im_kimia_assistant_msg_start|>": 151671, + } + + for token_str, token_id in kimiaudio_special_tokens.items(): + # Only add if not already present + if token_id not in self._added_tokens_decoder: + self._added_tokens_decoder[token_id] = AddedToken( + token_str, single_word=True, normalized=False, special=True + ) + # Also ensure it's in _token_to_id and _id_to_token + if token_str not in self._token_to_id: + self._token_to_id[token_str] = token_id + if token_id not in self._id_to_token: + self._id_to_token[token_id] = token_str + + def num_special_tokens_to_add(self) -> int: + return 0 + + @property + def all_special_tokens(self) -> list[str]: + return list(self._added_tokens_decoder.values()) + + @property + def all_special_ids(self) -> list[int]: + return list(self._added_tokens_decoder.keys()) + + @property + def bos_token_id(self) -> int: + return self._bos_token_id + + @property + def eos_token_id(self) -> int: + return self._eos_token_id + + @property + def pad_token_id(self) -> int: + return self._pad_token_id + + @property + def is_fast(self) -> bool: + return False + + @property + def vocab_size(self) -> int: + return self._tokenizer.n_vocab + + @property + def max_token_id(self) -> int: + return self._tokenizer.n_vocab - 1 + + @property + def max_chars_per_token(self) -> int: + return self._max_chars_per_token + + @property + def truncation_side(self) -> str: + return self._truncation_side + + @property + def added_tokens_decoder(self) -> dict[int, Any]: + return self._added_tokens_decoder + + @added_tokens_decoder.setter + def added_tokens_decoder(self, value: dict[int, Any]) -> None: + """Set added tokens decoder and update special token IDs.""" + self._added_tokens_decoder = value + # Update special token IDs if known tokens are added + for token_id, token in value.items(): + token_str = str(token) if hasattr(token, "__str__") else token + if "<|im_kimia_user_msg_start|>" in token_str: + self._bos_token_id = token_id + elif "<|im_msg_end|>" in token_str or "<|im_end|>" in token_str: + self._eos_token_id = token_id + + def get_vocab(self) -> dict[str, int]: + return dict(self._token_to_id) + + def __len__(self) -> int: + """Return vocab size for compatibility with HF tokenizer interface.""" + return self._tokenizer.n_vocab + + def get_added_vocab(self) -> dict[str, int]: + return { + str(token): token_id + for token_id, token in self._added_tokens_decoder.items() + } + + def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]: + if max_length is None or len(tokens) <= max_length: + return tokens + if self.truncation_side == "left": + return tokens[-max_length:] + return tokens[:max_length] + + def encode( + self, + text: str, + truncation: bool | None = None, + max_length: int | None = None, + add_special_tokens: bool = True, + **kwargs, + ) -> list[int]: + del add_special_tokens + # Allow Kimi-Audio special tokens to be encoded + tokens = self._tokenizer.encode( + text, + allowed_special={ + "<|im_media_begin|>", + "<|im_media_end|>", + "<|im_kimia_text_blank|>", + "<|im_msg_end|>", + "<|im_kimia_user_msg_start|>", + "<|im_kimia_assistant_msg_start|>", + }, + ) + if truncation: + tokens = self._maybe_truncate(tokens, max_length) + return tokens + + def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: + """Decode token IDs to text, optionally skipping special tokens.""" + if isinstance(ids, int): + ids = [ids] + if skip_special_tokens: + # Skip tokens that are in special_tokens (loaded from config) + special_ids = set(self._special_tokens.values()) + ids = [token_id for token_id in ids if token_id not in special_ids] + return self._tokenizer.decode(ids) + + @overload + def convert_tokens_to_ids(self, tokens: str) -> int: ... + + @overload + def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ... + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + if isinstance(tokens, str): + return self._token_to_id.get(tokens, self._unk_token_id) + return [self._token_to_id.get(token, self._unk_token_id) for token in tokens] + + def convert_ids_to_tokens( + self, ids: list[int], skip_special_tokens: bool = False + ) -> list[str]: + tokens = [] + for token_id in ids: + if skip_special_tokens and token_id in self._added_tokens_decoder: + continue + tokens.append(self._id_to_token.get(token_id, "<|unk|>")) + return tokens + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + token_ids = self.convert_tokens_to_ids(tokens) + return self.decode(token_ids, skip_special_tokens=False) + + def __call__( + self, + text: str | list[str], + text_pair: str | None = None, + add_special_tokens: bool = True, + truncation: bool = False, + max_length: int | None = None, + **kwargs, + ) -> BatchEncoding: + if text_pair is not None: + raise NotImplementedError( + "text_pair is not supported for KimiAudioTokenizer." + ) + + if isinstance(text, list): + input_ids_batch: list[list[int]] = [ + self.encode( + item, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) + for item in text + ] + attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch] + return BatchEncoding( + {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + ) + + input_ids = self.encode( + text, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) + attention_mask = [1] * len(input_ids) + return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}) + + def get_chat_template( + self, chat_template: str | None, tools: list[dict[str, Any]] | None = None + ) -> str | None: + del tools + return chat_template + + def apply_chat_template( + self, + messages: list[ChatCompletionMessageParam] | None = None, + tools: list[dict[str, Any]] | None = None, + chat_template: str | None = None, + tokenize: bool = False, + **kwargs, + ) -> str | list[int]: + # Handle both 'messages' (protocol) and 'conversation' (caller) parameter names + conversation = messages if messages is not None else kwargs.get("conversation") + if conversation is None: + raise ValueError("Either 'messages' or 'conversation' must be provided.") + template = self.get_chat_template(chat_template, tools=tools) + if template is None: + raise ValueError( + "No chat template available. Provide `chat_template` explicitly." + ) + # Use render_jinja_template instead of apply_chat_template + # Note: render_jinja_template returns ([prompts], [generation_indices]) + rendered, _ = hf_chat_utils.render_jinja_template( + conversation, + chat_template=template, + tools=tools, + **kwargs, + ) + # Extract the first (and usually only) prompt + prompt = rendered[0] if rendered else "" + if tokenize: + return self.encode(prompt, add_special_tokens=False) + return prompt diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index 4512f766c..63711cbe0 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -35,6 +35,7 @@ _VLLM_TOKENIZERS = { "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"), "grok2": ("grok2", "Grok2Tokenizer"), "hf": ("hf", "CachedHfTokenizer"), + "kimi_audio": ("kimi_audio", "KimiAudioTokenizer"), "mistral": ("mistral", "MistralTokenizer"), "qwen_vl": ("qwen_vl", "QwenVLTokenizer"), } diff --git a/vllm/transformers_utils/chat_templates/template_kimi_audio.jinja b/vllm/transformers_utils/chat_templates/template_kimi_audio.jinja new file mode 100644 index 000000000..269359e9b --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_kimi_audio.jinja @@ -0,0 +1,13 @@ +{% set messages = conversations[0] if conversations else [] -%} +{% if messages and messages[0]['role'] == 'system' -%} + {% set loop_messages = messages[1:] -%} +{% else -%} + {% set loop_messages = messages -%} +{% endif -%} +{% for message in loop_messages -%} + {% if message['role'] == 'user' -%} + <|im_kimia_user_msg_start|>{{ message['content'] }}<|im_msg_end|><|im_kimia_assistant_msg_start|> + {%- elif message['role'] == 'assistant' -%} + {{ message['content'] }}<|im_kimia_text_eos|> + {%- endif -%} +{% endfor -%} diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 50c944e9d..21b940662 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -10,23 +10,6 @@ reasons: import importlib -_CLASS_TO_MODULE: dict[str, str] = { - "BagelProcessor": "vllm.transformers_utils.processors.bagel", - "DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2", - "FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2", - "FunASRProcessor": "vllm.transformers_utils.processors.funasr", - "GLM4VProcessor": "vllm.transformers_utils.processors.glm4v", - "HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl", - "HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image", - "MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral", - "MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral", - "OvisProcessor": "vllm.transformers_utils.processors.ovis", - "Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5", - "QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl", - "Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr", -} - - __all__ = [ "BagelProcessor", "DeepseekVLV2Processor", @@ -35,6 +18,7 @@ __all__ = [ "GLM4VProcessor", "HunYuanVLProcessor", "HunYuanVLImageProcessor", + "KimiAudioProcessor", "MistralCommonPixtralProcessor", "MistralCommonVoxtralProcessor", "OvisProcessor", @@ -43,6 +27,23 @@ __all__ = [ "Qwen3ASRProcessor", ] +_CLASS_TO_MODULE: dict[str, str] = { + "BagelProcessor": "vllm.transformers_utils.processors.bagel", + "DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2", + "FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2", + "FunASRProcessor": "vllm.transformers_utils.processors.funasr", + "GLM4VProcessor": "vllm.transformers_utils.processors.glm4v", + "HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl", + "HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image", + "KimiAudioProcessor": "vllm.transformers_utils.processors.kimi_audio", + "MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral", + "MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral", + "OvisProcessor": "vllm.transformers_utils.processors.ovis", + "Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5", + "QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl", + "Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr", +} + def __getattr__(name: str): if name in _CLASS_TO_MODULE: diff --git a/vllm/transformers_utils/processors/kimi_audio.py b/vllm/transformers_utils/processors/kimi_audio.py new file mode 100644 index 000000000..614fdf4fe --- /dev/null +++ b/vllm/transformers_utils/processors/kimi_audio.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# ruff: noqa +# mypy: ignore-errors +# coding=utf-8 +# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Processor for Kimi-Audio ASR model.""" + +from collections.abc import Mapping +from typing import Any + +import numpy as np +import torch +from transformers import AutoFeatureExtractor, BatchFeature, ProcessorMixin +from transformers.audio_utils import AudioInput +from transformers.tokenization_utils_base import TextInput + +from vllm.tokenizers.kimi_audio import KimiAudioTokenizer + + +def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor: + """Compute output lengths after Whisper feature extraction.""" + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ( + ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + ) + return output_lengths + + +class KimiAudioProcessor(ProcessorMixin): + r""" + Constructs a Kimi-Audio processor. + + [`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer. + See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information. + + Args: + feature_extractor ([`WhisperFeatureExtractor`], *optional*): + The audio feature extractor. + tokenizer ([`PreTrainedTokenizer`], *optional*): + The text tokenizer. + """ + + # Required for ProcessorMixin + attributes = ["feature_extractor", "tokenizer"] + feature_extractor_class = "AutoFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + # Special token IDs + KIMIA_MEDIA_BEGIN: int = 151661 + KIMIA_MEDIA_END: int = 151663 + KIMIA_TEXT_BLANK: int = 151666 + + # Audio processing constants + AUDIO_SEQ_LEN: int = 376 + + def __init__(self, feature_extractor=None, tokenizer=None, **kwargs): + # Pass feature_extractor and tokenizer to parent ProcessorMixin + super().__init__( + feature_extractor=feature_extractor, + tokenizer=tokenizer, + **kwargs, + ) + + def check_argument_for_proper_class(self, attribute_name: str, argument: Any): + """Override to skip class validation for custom tokenizer.""" + # Skip validation for tokenizer since KimiAudioTokenizer doesn't inherit + # from PreTrainedTokenizerBase but is compatible + if attribute_name == "tokenizer" and argument is not None: + return + # For other attributes, use default validation + super().check_argument_for_proper_class(attribute_name, argument) + + def __call__( + self, + text: TextInput = None, + audio: AudioInput = None, + return_tensors: str = "pt", + **kwargs, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and audio(s). + + Args: + text (`str`, `List[str]`): + The sequence or batch of sequences to be encoded. + audio (`np.ndarray`, `List[np.ndarray]`): + The audio or batch of audio to be prepared. Each audio can be a NumPy array. + return_tensors (`str`): + The type of tensors to return ("pt", "np", etc.) + """ + if text is None: + raise ValueError("You need to specify either a `text` input to process.") + + # Process audio if provided + if audio is not None: + # Ensure audio is a list + if isinstance(audio, np.ndarray): + audio = [audio] + + # Pad audio to hop length (required by WhisperFeatureExtractor) + hop_length = self.feature_extractor.hop_length + padded_audio = [] + for aud in audio: + length = aud.shape[-1] + if length % hop_length != 0: + pad_length = hop_length - (length % hop_length) + aud = np.pad( + aud, (0, pad_length), mode="constant", constant_values=0 + ) + padded_audio.append(aud) + + # Use feature_extractor directly like Qwen3ASR does + audio_inputs = self.feature_extractor( + padded_audio, + sampling_rate=16000, + padding=True, + return_attention_mask=True, + return_tensors=return_tensors, + ) + # Rename to match Kimi-Audio expectations + if "input_features" in audio_inputs: + audio_inputs["whisper_input_features"] = audio_inputs.pop( + "input_features" + ) + if "attention_mask" in audio_inputs: + audio_inputs["feature_attention_mask"] = audio_inputs.pop( + "attention_mask" + ) + else: + audio_inputs = {} + + # Handle text input - can be string or token IDs from vLLM processor + if isinstance(text, list) and len(text) > 0 and isinstance(text[0], int): + # Text is already token IDs (from vLLM processor) - just wrap + text_inputs = {"input_ids": torch.tensor([text], dtype=torch.long)} + else: + # Text is string - tokenize + if not isinstance(text, list): + text = [text] + + text_inputs = self.tokenizer( + text, return_tensors=return_tensors, padding=True + ) + + return BatchFeature( + data={**text_inputs, **audio_inputs}, + tensor_type=return_tensors, + )