[Misc] Clean up renderers (#36770)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -6,9 +6,6 @@ from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
@@ -21,7 +18,10 @@ from vllm.config.multimodal import (
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs, batched_tensors_equal
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
InputProcessingContext,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
|
||||
@@ -74,20 +74,6 @@ def glmasr_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
|
||||
return mm_data
|
||||
|
||||
|
||||
# For some multimodal models, tokenizer will always add bos_token
|
||||
# at the beginning of prompt by default, causing hf_processor outputs
|
||||
# incorrect token ids. So we need use `add_special_tokens=False` here
|
||||
# to leave bos_token to be added by the processor.
|
||||
_ADD_SPECIAL_TOKENS_OVERRIDES = {
|
||||
"lfm2_vl": False,
|
||||
"nemotron_parse": False,
|
||||
"ovis": False,
|
||||
"ovis2_5": False,
|
||||
"paligemma": False,
|
||||
"ultravox": False,
|
||||
"whisper": False,
|
||||
}
|
||||
|
||||
_IGNORE_MM_KEYS = {
|
||||
# In Ultravox, the audio_features can be different depending on padding
|
||||
# The slight difference should not be a problem though, since
|
||||
@@ -152,63 +138,34 @@ def get_text_token_prompts(
|
||||
parsed_data = processor.info.parse_mm_data(mm_data)
|
||||
mm_counts = {k: len(vs) for k, vs in parsed_data.items()}
|
||||
|
||||
text_prompt: str | None
|
||||
token_prompt: list[int]
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
# ChatCompletionRequest only supports ImageChunk natively;
|
||||
# for other modalities (e.g. audio), fall back to the model's
|
||||
# own dummy inputs builder which knows the right placeholders.
|
||||
has_non_image = any(
|
||||
k != "image" and count > 0 for k, count in mm_counts.items()
|
||||
inputs = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
mm_options={},
|
||||
# Assume all Mistral models define this extra argument
|
||||
mm_data=mm_data, # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
if has_non_image:
|
||||
inputs = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
mm_options={},
|
||||
)
|
||||
text_prompt = None
|
||||
token_prompt = (
|
||||
inputs.prompt
|
||||
if isinstance(inputs.prompt, list)
|
||||
else tokenizer.encode(inputs.prompt, add_special_tokens=False)
|
||||
)
|
||||
else:
|
||||
images = parsed_data.get("image", [])
|
||||
request = ChatCompletionRequest(
|
||||
messages=[
|
||||
UserMessage(
|
||||
content=[
|
||||
TextChunk(text=""),
|
||||
*(ImageChunk(image=image) for image in images),
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
|
||||
# Mistral does not support decode_tokens with
|
||||
# skip_special_tokens=False
|
||||
text_prompt = None
|
||||
token_prompt = res.tokens
|
||||
else:
|
||||
inputs = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
mm_options={},
|
||||
)
|
||||
# Some models (e.g., Kimi-Audio) return token IDs directly instead of str
|
||||
if isinstance(inputs.prompt, list):
|
||||
text_prompt = None
|
||||
token_prompt = inputs.prompt
|
||||
else:
|
||||
assert isinstance(inputs.prompt, str)
|
||||
text_prompt = inputs.prompt
|
||||
token_prompt = tokenizer.encode(
|
||||
text_prompt,
|
||||
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type, True),
|
||||
)
|
||||
|
||||
text_prompt: str | None
|
||||
token_prompt: list[int]
|
||||
if isinstance(inputs.prompt, list):
|
||||
text_prompt = None
|
||||
token_prompt = inputs.prompt
|
||||
elif isinstance(inputs.prompt, str):
|
||||
text_prompt = inputs.prompt
|
||||
token_prompt = tokenizer.encode(
|
||||
text_prompt,
|
||||
**processor.info.get_default_tok_params().get_encode_kwargs(),
|
||||
)
|
||||
else:
|
||||
raise TypeError(type(inputs.prompt))
|
||||
|
||||
return text_prompt, token_prompt
|
||||
|
||||
@@ -448,7 +405,7 @@ def test_processing_correctness(
|
||||
)
|
||||
if model_id == "mistralai/Voxtral-Mini-4B-Realtime-2602":
|
||||
pytest.skip(
|
||||
"Voxtral Realtime doesn't make use of any place-holder"
|
||||
"Voxtral Realtime doesn't make use of any place-holder "
|
||||
"tokens and hence cannot pass the processing "
|
||||
"correctness test as is. Let's revisit adapting this "
|
||||
"test once more realtime models exist."
|
||||
|
||||
@@ -532,6 +532,22 @@ class ModelConfig:
|
||||
self._architecture = arch
|
||||
logger.info("Resolved architecture: %s", arch)
|
||||
|
||||
# Set default tokenizer modes based on model architecture
|
||||
if self.tokenizer_mode == "auto":
|
||||
if arch == "Grok1ForCausalLM":
|
||||
self.tokenizer_mode = "grok2"
|
||||
elif arch == "MoonshotKimiaForCausalLM":
|
||||
self.tokenizer_mode = "kimi_audio"
|
||||
elif arch == "QwenVLForConditionalGeneration":
|
||||
self.tokenizer_mode = "qwen_vl"
|
||||
|
||||
if self.tokenizer_mode != "auto":
|
||||
logger.info(
|
||||
"Defaulting to tokenizer_mode=%r for %s",
|
||||
self.tokenizer_mode,
|
||||
arch,
|
||||
)
|
||||
|
||||
# Init pooler config if needed
|
||||
if self.runner_type == "pooling":
|
||||
if self.pooler_config is None:
|
||||
|
||||
@@ -10,11 +10,13 @@ from typing import Any, ClassVar, Literal
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
from transformers import BatchFeature
|
||||
from transformers import WhisperConfig as HFWhisperConfig
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@@ -47,7 +49,10 @@ from vllm.multimodal.processing import (
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
)
|
||||
from vllm.multimodal.processing.processor import BaseMultiModalProcessor
|
||||
from vllm.multimodal.processing.processor import (
|
||||
BaseMultiModalProcessor,
|
||||
ProcessorInputs,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
@@ -59,6 +64,15 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||
|
||||
|
||||
def _get_whisper_local_path(repo_id: str):
|
||||
if os.path.exists(repo_id):
|
||||
repo_local_path = repo_id
|
||||
else:
|
||||
repo_local_path = snapshot_download(repo_id, local_files_only=True)
|
||||
|
||||
return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after Whisper feature extraction.
|
||||
|
||||
@@ -88,10 +102,10 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
# Load Whisper config from subfolder (authoritative source)
|
||||
# Kimi-Audio stores Whisper config in whisper-large-v3/config.json
|
||||
model_path = vllm_config.model_config.model
|
||||
whisper_config_path = os.path.join(model_path, KIMIA_WHISPER_SUBFOLDER)
|
||||
|
||||
# Load WhisperConfig from the subfolder
|
||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path)
|
||||
whisper_dir = _get_whisper_local_path(model_path)
|
||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
|
||||
|
||||
# Temporarily replace hf_config for WhisperEncoder.__init__()
|
||||
original_config = vllm_config.model_config.hf_config
|
||||
@@ -114,28 +128,18 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
class KimiAudioProcessingInfo(BaseProcessingInfo):
|
||||
"""Processing info for vLLM registry."""
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.model_config.hf_config
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
|
||||
"""Get KimiAudioProcessor with feature extractor and tokenizer."""
|
||||
# Use vLLM's cached loader for feature extractor
|
||||
feature_extractor = cached_feature_extractor_from_config(
|
||||
self.ctx.model_config,
|
||||
subfolder=KIMIA_WHISPER_SUBFOLDER,
|
||||
)
|
||||
|
||||
# Use vLLM's standard tokenizer loading (respects tokenizer_mode)
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# Construct processor directly
|
||||
return KimiAudioProcessor(
|
||||
feature_extractor=feature_extractor,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.get_tokenizer(),
|
||||
)
|
||||
|
||||
def get_feature_extractor(self, **kwargs: object):
|
||||
"""Get feature extractor using vLLM's cached loader."""
|
||||
return cached_feature_extractor_from_config(
|
||||
self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
|
||||
)
|
||||
@@ -144,26 +148,16 @@ class KimiAudioProcessingInfo(BaseProcessingInfo):
|
||||
return {"audio": 1}
|
||||
|
||||
def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
|
||||
"""Get data parser for audio inputs."""
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
return KimiAudioMultiModalDataParser(
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
|
||||
|
||||
class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
|
||||
"""Dummy inputs builder for vLLM registry."""
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
|
||||
"""Return dummy text as token IDs directly."""
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
if num_audios == 0:
|
||||
return [198] # "Transcribe" tokenized
|
||||
# Return as token IDs directly to avoid tokenizer issues
|
||||
return [
|
||||
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
|
||||
KimiAudioProcessor.KIMIA_TEXT_BLANK,
|
||||
KimiAudioProcessor.KIMIA_MEDIA_END,
|
||||
] * num_audios
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
@@ -186,6 +180,29 @@ class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo
|
||||
),
|
||||
}
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
) -> ProcessorInputs:
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
dummy_tokens = (
|
||||
[198]
|
||||
if num_audios == 0
|
||||
else [
|
||||
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
|
||||
KimiAudioProcessor.KIMIA_TEXT_BLANK,
|
||||
KimiAudioProcessor.KIMIA_MEDIA_END,
|
||||
]
|
||||
* num_audios
|
||||
)
|
||||
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
|
||||
|
||||
|
||||
# Field config for Kimi-Audio multimodal data
|
||||
_KIMIAUDIO_FIELD_CONFIG = {
|
||||
@@ -197,10 +214,6 @@ _KIMIAUDIO_FIELD_CONFIG = {
|
||||
class KimiAudioMultiModalDataParser(MultiModalDataParser):
|
||||
"""Custom data parser for Kimi-Audio multimodal data."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Whisper expects 16kHz audio
|
||||
super().__init__(target_sr=16000, **kwargs)
|
||||
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
|
||||
@@ -589,9 +602,8 @@ class KimiAudioForConditionalGeneration(
|
||||
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
# Load Whisper encoder weights from subfolder
|
||||
whisper_path = os.path.join(
|
||||
self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors"
|
||||
)
|
||||
whisper_dir = _get_whisper_local_path(self.model_path)
|
||||
whisper_path = os.path.join(whisper_dir, "model.safetensors")
|
||||
if os.path.exists(whisper_path):
|
||||
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
|
||||
loaded.update(whisper_loaded)
|
||||
|
||||
@@ -63,12 +63,10 @@ from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
@@ -546,9 +544,6 @@ class Llama4VisionModel(nn.Module):
|
||||
|
||||
|
||||
class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
def __init__(self, ctx: InputProcessingContext) -> None:
|
||||
super().__init__(ctx)
|
||||
|
||||
def get_hf_config(self) -> Llama4Config:
|
||||
return self.ctx.get_hf_config(Llama4Config)
|
||||
|
||||
@@ -557,9 +552,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
|
||||
)
|
||||
|
||||
def get_default_tok_params(self) -> TokenizeParams:
|
||||
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
# Although vLLM can support more images from an infra capability
|
||||
# perspective, we do not recommend using >10 images in practice.
|
||||
@@ -597,10 +589,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo])
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
if mm_data is None:
|
||||
return tokenizer(prompt, add_special_tokens=False) # exclude bos
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
|
||||
@@ -172,12 +172,20 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
mm_data: MultiModalDataDict | None = None,
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
dummy_images = dummy_mm_data.get("image", [])
|
||||
dummy_mm_data = (
|
||||
self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
if mm_data is None
|
||||
else mm_data
|
||||
)
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
dummy_images = (
|
||||
[] if "image" not in dummy_mm_data else dummy_mm_items["image"].get_all()
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
messages=[
|
||||
@@ -192,8 +200,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
dummy_tokens = res.tokens
|
||||
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
|
||||
|
||||
|
||||
|
||||
@@ -150,13 +150,21 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
mm_data: MultiModalDataDict | None = None,
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
feature_extractor = self.info.get_hf_processor().feature_extractor
|
||||
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
dummy_audios = dummy_mm_data.get("audio", [])
|
||||
dummy_mm_data = (
|
||||
self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
if mm_data is None
|
||||
else mm_data
|
||||
)
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
dummy_audios = (
|
||||
[] if "audio" not in dummy_mm_data else dummy_mm_items["audio"].get_all()
|
||||
)
|
||||
|
||||
audio_chunks: list[AudioChunk] = []
|
||||
format = "wav"
|
||||
|
||||
@@ -6,11 +6,10 @@ from vllm.config import VllmConfig
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.qwen_vl import QwenVLTokenizer
|
||||
|
||||
from .base import BaseRenderer
|
||||
from .hf import HfRenderer
|
||||
|
||||
|
||||
class QwenVLRenderer(BaseRenderer[QwenVLTokenizer]):
|
||||
class QwenVLRenderer(HfRenderer):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
|
||||
@@ -80,13 +80,6 @@ def renderer_from_config(config: "VllmConfig", **kwargs):
|
||||
model_config, **kwargs
|
||||
)
|
||||
|
||||
# Override tokenizer_mode for Kimi-Audio models
|
||||
if model_config.architecture == "MoonshotKimiaForCausalLM":
|
||||
tokenizer_mode = "kimi_audio"
|
||||
# Update model_config so other components (e.g., multimodal registry)
|
||||
# also use the correct tokenizer mode
|
||||
model_config.tokenizer_mode = "kimi_audio"
|
||||
|
||||
if (
|
||||
model_config.tokenizer_mode == "auto"
|
||||
and model_config.model_impl == "terratorch"
|
||||
|
||||
@@ -159,18 +159,6 @@ def resolve_tokenizer_args(
|
||||
):
|
||||
tokenizer_mode = "mistral"
|
||||
|
||||
# Try to use Grok2 tiktoken tokenizer if possible
|
||||
if tokenizer_mode == "auto" and any_pattern_in_repo_files(
|
||||
model_name_or_path=str(tokenizer_name),
|
||||
allow_patterns=["tokenizer.tok.json"],
|
||||
revision=revision,
|
||||
):
|
||||
tokenizer_mode = "grok2"
|
||||
|
||||
# Model-specific tokenizers
|
||||
if tokenizer_mode == "auto" and "/Qwen-VL" in str(tokenizer_name):
|
||||
tokenizer_mode = "qwen_vl"
|
||||
|
||||
# Fallback to HF tokenizer
|
||||
if tokenizer_mode == "auto":
|
||||
tokenizer_mode = "hf"
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/zai-org/CogAgent
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
||||
from transformers.image_utils import PILImageResampling
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2026 The Moonshot AI team and the HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,42 +17,13 @@
|
||||
# limitations under the License.
|
||||
"""Processor for Kimi-Audio ASR model."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoFeatureExtractor, BatchFeature, ProcessorMixin
|
||||
from transformers import BatchFeature, ProcessorMixin
|
||||
from transformers.audio_utils import AudioInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after Whisper feature extraction."""
|
||||
input_lengths_leave = input_lengths % 100
|
||||
feat_lengths = (input_lengths_leave - 1) // 2 + 1
|
||||
output_lengths = (
|
||||
((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
|
||||
)
|
||||
return output_lengths
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
|
||||
class KimiAudioProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Kimi-Audio processor.
|
||||
|
||||
[`KimiAudioProcessor`] offers all the functionalities of [`WhisperFeatureExtractor`], and a tokenizer.
|
||||
See the [`~KimiAudioProcessor.__call__`] and [`~KimiAudioProcessor.decode`] for more information.
|
||||
|
||||
Args:
|
||||
feature_extractor ([`WhisperFeatureExtractor`], *optional*):
|
||||
The audio feature extractor.
|
||||
tokenizer ([`PreTrainedTokenizer`], *optional*):
|
||||
The text tokenizer.
|
||||
"""
|
||||
|
||||
# Required for ProcessorMixin
|
||||
attributes = ["feature_extractor", "tokenizer"]
|
||||
feature_extractor_class = "AutoFeatureExtractor"
|
||||
@@ -69,44 +38,30 @@ class KimiAudioProcessor(ProcessorMixin):
|
||||
AUDIO_SEQ_LEN: int = 376
|
||||
|
||||
def __init__(self, feature_extractor=None, tokenizer=None, **kwargs):
|
||||
# Pass feature_extractor and tokenizer to parent ProcessorMixin
|
||||
super().__init__(
|
||||
feature_extractor=feature_extractor,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def check_argument_for_proper_class(self, attribute_name: str, argument: Any):
|
||||
"""Override to skip class validation for custom tokenizer."""
|
||||
# Skip validation for tokenizer since KimiAudioTokenizer doesn't inherit
|
||||
# from PreTrainedTokenizerBase but is compatible
|
||||
if attribute_name == "tokenizer" and argument is not None:
|
||||
return
|
||||
# For other attributes, use default validation
|
||||
super().check_argument_for_proper_class(attribute_name, argument)
|
||||
self.feature_extractor = feature_extractor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput = None,
|
||||
audio: AudioInput = None,
|
||||
text: TextInput
|
||||
| PreTokenizedInput
|
||||
| list[TextInput]
|
||||
| list[PreTokenizedInput]
|
||||
| None = None,
|
||||
audio: AudioInput | None = None,
|
||||
return_tensors: str = "pt",
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and audio(s).
|
||||
if text is not None:
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`):
|
||||
The sequence or batch of sequences to be encoded.
|
||||
audio (`np.ndarray`, `List[np.ndarray]`):
|
||||
The audio or batch of audio to be prepared. Each audio can be a NumPy array.
|
||||
return_tensors (`str`):
|
||||
The type of tensors to return ("pt", "np", etc.)
|
||||
"""
|
||||
if text is None:
|
||||
raise ValueError("You need to specify either a `text` input to process.")
|
||||
text_inputs = self.tokenizer(
|
||||
text, return_tensors=return_tensors, padding=True
|
||||
)
|
||||
else:
|
||||
text_inputs = {}
|
||||
|
||||
# Process audio if provided
|
||||
if audio is not None:
|
||||
# Ensure audio is a list
|
||||
if isinstance(audio, np.ndarray):
|
||||
@@ -144,19 +99,6 @@ class KimiAudioProcessor(ProcessorMixin):
|
||||
else:
|
||||
audio_inputs = {}
|
||||
|
||||
# Handle text input - can be string or token IDs from vLLM processor
|
||||
if isinstance(text, list) and len(text) > 0 and isinstance(text[0], int):
|
||||
# Text is already token IDs (from vLLM processor) - just wrap
|
||||
text_inputs = {"input_ids": torch.tensor([text], dtype=torch.long)}
|
||||
else:
|
||||
# Text is string - tokenize
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
text, return_tensors=return_tensors, padding=True
|
||||
)
|
||||
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **audio_inputs},
|
||||
tensor_type=return_tensors,
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://huggingface.co/Qwen/Qwen-VL/blob/main/modeling_qwen.py
|
||||
# Copyright (c) Alibaba Cloud.
|
||||
from transformers.image_processing_utils_fast import BaseImageProcessorFast
|
||||
from transformers.image_utils import PILImageResampling
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
Reference in New Issue
Block a user