[Misc] Clean up renderers (#36770)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author: Cyrus Leung
Committed: 2026-03-12 00:39:29 +08:00, via GitHub
Parent: c84b519cf3
Commit: 196802dfa6
Stats: 12 changed files with 136 additions and 220 deletions


@@ -10,11 +10,13 @@ from typing import Any, ClassVar, Literal
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import BatchFeature
from transformers import WhisperConfig as HFWhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import (
@@ -47,7 +49,10 @@ from vllm.multimodal.processing import (
BaseProcessingInfo,
PromptReplacement,
)
from vllm.multimodal.processing.processor import BaseMultiModalProcessor
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
ProcessorInputs,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
@@ -59,6 +64,15 @@ from vllm.v1.sample.metadata import SamplingMetadata
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
def _get_whisper_local_path(repo_id: str):
if os.path.exists(repo_id):
repo_local_path = repo_id
else:
repo_local_path = snapshot_download(repo_id, local_files_only=True)
return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
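This helper makes the Whisper subfolder resolvable whether the model was passed as a local directory or as a Hub repo id that already exists in the local cache. A usage sketch (the repo id is a placeholder, not the real model name):

    # Resolve the whisper-large-v3 subfolder without touching the network.
    # local_files_only=True makes snapshot_download reuse the existing cache
    # and fail fast (LocalEntryNotFoundError) if the repo was never downloaded.
    whisper_dir = _get_whisper_local_path("org/kimi-audio-model")  # placeholder id
    config_path = os.path.join(whisper_dir, "config.json")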
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
"""Compute output lengths after Whisper feature extraction.
@@ -88,10 +102,10 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
# Load Whisper config from subfolder (authoritative source)
# Kimi-Audio stores Whisper config in whisper-large-v3/config.json
model_path = vllm_config.model_config.model
whisper_config_path = os.path.join(model_path, KIMIA_WHISPER_SUBFOLDER)
# Load WhisperConfig from the subfolder
whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path)
whisper_dir = _get_whisper_local_path(model_path)
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
# Temporarily replace hf_config for WhisperEncoder.__init__()
original_config = vllm_config.model_config.hf_config
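The lines above belong to a swap-and-restore idiom: hf_config is temporarily replaced so that WhisperEncoder.__init__ sees the Whisper config rather than the top-level Kimi-Audio config. A sketch of how that can look (the import path, call shape, and try/finally restore are assumptions; the restore itself sits below the visible hunk):

    from vllm.model_executor.models.whisper import WhisperEncoder  # path assumed

    class KimiAudioWhisperEncoderSketch(WhisperEncoder):  # illustrative only
        def __init__(self, vllm_config, whisper_config):
            original_config = vllm_config.model_config.hf_config
            vllm_config.model_config.hf_config = whisper_config
            try:
                super().__init__(vllm_config=vllm_config)  # reads hf_config
            finally:
                # Put the Kimi-Audio config back for the rest of the model.
                vllm_config.model_config.hf_config = original_config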
@@ -114,28 +128,18 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
class KimiAudioProcessingInfo(BaseProcessingInfo):
"""Processing info for vLLM registry."""
def get_hf_config(self):
return self.ctx.model_config.hf_config
def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
"""Get KimiAudioProcessor with feature extractor and tokenizer."""
# Use vLLM's cached loader for feature extractor
feature_extractor = cached_feature_extractor_from_config(
self.ctx.model_config,
subfolder=KIMIA_WHISPER_SUBFOLDER,
)
# Use vLLM's standard tokenizer loading (respects tokenizer_mode)
tokenizer = self.get_tokenizer()
# Construct processor directly
return KimiAudioProcessor(
feature_extractor=feature_extractor,
tokenizer=tokenizer,
tokenizer=self.get_tokenizer(),
)
def get_feature_extractor(self, **kwargs: object):
"""Get feature extractor using vLLM's cached loader."""
return cached_feature_extractor_from_config(
self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
)
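cached_feature_extractor_from_config is vLLM's memoized loader; underneath, the effect is roughly a transformers from_pretrained call pointed at the subfolder. An approximate sketch (caching and config plumbing omitted; the path is a placeholder):

    from transformers import WhisperFeatureExtractor

    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        "/path/to/kimi-audio",         # placeholder checkpoint location
        subfolder="whisper-large-v3",  # KIMIA_WHISPER_SUBFOLDER
    )
    assert feature_extractor.sampling_rate == 16_000  # Whisper's native rate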
@@ -144,26 +148,16 @@ class KimiAudioProcessingInfo(BaseProcessingInfo):
return {"audio": 1}
def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
"""Get data parser for audio inputs."""
feature_extractor = self.get_feature_extractor()
return KimiAudioMultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
"""Dummy inputs builder for vLLM registry."""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
"""Return dummy text as token IDs directly."""
num_audios = mm_counts.get("audio", 0)
if num_audios == 0:
return [198] # "Transcribe" tokenized
# Return as token IDs directly to avoid tokenizer issues
return [
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
KimiAudioProcessor.KIMIA_TEXT_BLANK,
KimiAudioProcessor.KIMIA_MEDIA_END,
] * num_audios
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
return ""
def get_dummy_mm_data(
self,
@@ -186,6 +180,29 @@ class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo
),
}
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
) -> ProcessorInputs:
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
num_audios = mm_counts.get("audio", 0)
dummy_tokens = (
[198]
if num_audios == 0
else [
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
KimiAudioProcessor.KIMIA_TEXT_BLANK,
KimiAudioProcessor.KIMIA_MEDIA_END,
]
* num_audios
)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
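The placeholder tokens move out of get_dummy_text (which now returns an empty string, satisfying the str signature) and into get_dummy_processor_inputs, where token IDs can be handed to ProcessorInputs directly without a tokenizer round-trip. The expansion in isolation (B/T/E stand in for the KIMIA_MEDIA_BEGIN / KIMIA_TEXT_BLANK / KIMIA_MEDIA_END IDs; their values are not repeated here):

    def dummy_prompt(num_audios: int, B: int, T: int, E: int) -> list[int]:
        # One BEGIN/BLANK/END triple per audio; a lone filler token otherwise.
        return [198] if num_audios == 0 else [B, T, E] * num_audios

    assert dummy_prompt(0, 1, 2, 3) == [198]
    assert dummy_prompt(2, 1, 2, 3) == [1, 2, 3, 1, 2, 3]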
# Field config for Kimi-Audio multimodal data
_KIMIAUDIO_FIELD_CONFIG = {
@@ -197,10 +214,6 @@ _KIMIAUDIO_FIELD_CONFIG = {
class KimiAudioMultiModalDataParser(MultiModalDataParser):
"""Custom data parser for Kimi-Audio multimodal data."""
def __init__(self, **kwargs):
# Whisper expects 16kHz audio
super().__init__(target_sr=16000, **kwargs)
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
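With the hardcoded super().__init__(target_sr=16000, ...) removed, the target rate flows in from get_data_parser above, which reads it off the feature extractor. For intuition about what target_sr means for incoming audio, a toy linear-interpolation resampler (vLLM's parser uses its own resampling; this is illustrative only):

    import numpy as np

    def resample_linear(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        if orig_sr == target_sr:
            return audio
        duration = audio.shape[-1] / orig_sr
        new_len = int(round(duration * target_sr))
        old_t = np.linspace(0.0, duration, num=audio.shape[-1], endpoint=False)
        new_t = np.linspace(0.0, duration, num=new_len, endpoint=False)
        return np.interp(new_t, old_t, audio)

    clip_44k = np.random.randn(44_100).astype(np.float32)  # 1 s at 44.1 kHz
    clip_16k = resample_linear(clip_44k, 44_100, 16_000)
    assert clip_16k.shape[0] == 16_000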
@@ -589,9 +602,8 @@ class KimiAudioForConditionalGeneration(
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
# Load Whisper encoder weights from subfolder
whisper_path = os.path.join(
self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors"
)
whisper_dir = _get_whisper_local_path(self.model_path)
whisper_path = os.path.join(whisper_dir, "model.safetensors")
if os.path.exists(whisper_path):
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
loaded.update(whisper_loaded)
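The weights path now goes through the same helper as the config, so encoder weights load from a cached Hub snapshot as well as from a local directory. A minimal sketch of reading such a safetensors file (the real _load_whisper_weights_from_file presumably also remaps parameter names; that step is omitted):

    from safetensors import safe_open

    def load_whisper_state_dict(whisper_path: str) -> dict:
        state = {}
        with safe_open(whisper_path, framework="pt") as f:
            for name in f.keys():  # tensor names in the checkpoint
                state[name] = f.get_tensor(name)
        return state

This is the last Kimi-Audio hunk; the hunks that follow are in the Llama4 (Mllama4) multimodal processor.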


@@ -63,12 +63,10 @@ from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -546,9 +544,6 @@ class Llama4VisionModel(nn.Module):
class Mllama4ProcessingInfo(BaseProcessingInfo):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(ctx)
def get_hf_config(self) -> Llama4Config:
return self.ctx.get_hf_config(Llama4Config)
@@ -557,9 +552,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
)
def get_default_tok_params(self) -> TokenizeParams:
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
# Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice.
@@ -597,10 +589,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo])
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if mm_data is None:
return tokenizer(prompt, add_special_tokens=False) # exclude bos
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
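The deleted branch tokenized text-only prompts by hand with add_special_tokens=False to avoid duplicating the BOS token; after this cleanup, the base _call_hf_processor is trusted to handle mm_data being None. What the flag changes, shown with a generic HF tokenizer (the model choice is illustrative only):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # any tokenizer
    with_special = tok("hello world")["input_ids"]  # specials prepended/appended
    bare = tok("hello world", add_special_tokens=False)["input_ids"]
    assert bare != with_special and len(bare) < len(with_special)

The remaining two files make the same change to the Pixtral and Voxtral dummy-input builders.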


@@ -172,12 +172,20 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
mm_data: MultiModalDataDict | None = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_images = dummy_mm_data.get("image", [])
dummy_mm_data = (
self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
if mm_data is None
else mm_data
)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
dummy_images = (
[] if "image" not in dummy_mm_data else dummy_mm_items["image"].get_all()
)
request = ChatCompletionRequest(
messages=[
@@ -192,8 +200,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
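Both this Pixtral builder and the Voxtral builder below gain an optional mm_data parameter, so callers can inject pre-built multimodal data instead of always synthesizing it, and the parsed items are now computed once up front and reused. The shape of the pattern in isolation (a generic sketch, not the vLLM classes themselves):

    def get_dummy_processor_inputs(self, seq_len, mm_counts, mm_options, mm_data=None):
        dummy_mm_data = (
            self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
            if mm_data is None
            else mm_data  # caller-supplied data short-circuits synthesis
        )
        items = self.info.parse_mm_data(dummy_mm_data)  # parse once, reuse
        images = items["image"].get_all() if "image" in dummy_mm_data else []
        ...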


@@ -150,13 +150,21 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions],
mm_data: MultiModalDataDict | None = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
feature_extractor = self.info.get_hf_processor().feature_extractor
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_audios = dummy_mm_data.get("audio", [])
dummy_mm_data = (
self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
if mm_data is None
else mm_data
)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
dummy_audios = (
[] if "audio" not in dummy_mm_data else dummy_mm_items["audio"].get_all()
)
audio_chunks: list[AudioChunk] = []
format = "wav"