[Misc] Clean up renderers (#36770)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -10,11 +10,13 @@ from typing import Any, ClassVar, Literal
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
from transformers import BatchFeature
|
||||
from transformers import WhisperConfig as HFWhisperConfig
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
@@ -47,7 +49,10 @@ from vllm.multimodal.processing import (
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
)
|
||||
from vllm.multimodal.processing.processor import BaseMultiModalProcessor
|
||||
from vllm.multimodal.processing.processor import (
|
||||
BaseMultiModalProcessor,
|
||||
ProcessorInputs,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
@@ -59,6 +64,15 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||
|
||||
|
||||
def _get_whisper_local_path(repo_id: str):
|
||||
if os.path.exists(repo_id):
|
||||
repo_local_path = repo_id
|
||||
else:
|
||||
repo_local_path = snapshot_download(repo_id, local_files_only=True)
|
||||
|
||||
return os.path.join(repo_local_path, KIMIA_WHISPER_SUBFOLDER)
|
||||
|
||||
|
||||
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute output lengths after Whisper feature extraction.
|
||||
|
||||
@@ -88,10 +102,10 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
# Load Whisper config from subfolder (authoritative source)
|
||||
# Kimi-Audio stores Whisper config in whisper-large-v3/config.json
|
||||
model_path = vllm_config.model_config.model
|
||||
whisper_config_path = os.path.join(model_path, KIMIA_WHISPER_SUBFOLDER)
|
||||
|
||||
# Load WhisperConfig from the subfolder
|
||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_config_path)
|
||||
whisper_dir = _get_whisper_local_path(model_path)
|
||||
whisper_config = HFWhisperConfig.from_pretrained(whisper_dir)
|
||||
|
||||
# Temporarily replace hf_config for WhisperEncoder.__init__()
|
||||
original_config = vllm_config.model_config.hf_config
|
||||
@@ -114,28 +128,18 @@ class KimiAudioWhisperEncoder(WhisperEncoder):
|
||||
class KimiAudioProcessingInfo(BaseProcessingInfo):
|
||||
"""Processing info for vLLM registry."""
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.model_config.hf_config
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> KimiAudioProcessor:
|
||||
"""Get KimiAudioProcessor with feature extractor and tokenizer."""
|
||||
# Use vLLM's cached loader for feature extractor
|
||||
feature_extractor = cached_feature_extractor_from_config(
|
||||
self.ctx.model_config,
|
||||
subfolder=KIMIA_WHISPER_SUBFOLDER,
|
||||
)
|
||||
|
||||
# Use vLLM's standard tokenizer loading (respects tokenizer_mode)
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
# Construct processor directly
|
||||
return KimiAudioProcessor(
|
||||
feature_extractor=feature_extractor,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.get_tokenizer(),
|
||||
)
|
||||
|
||||
def get_feature_extractor(self, **kwargs: object):
|
||||
"""Get feature extractor using vLLM's cached loader."""
|
||||
return cached_feature_extractor_from_config(
|
||||
self.ctx.model_config, subfolder=KIMIA_WHISPER_SUBFOLDER
|
||||
)
|
||||
@@ -144,26 +148,16 @@ class KimiAudioProcessingInfo(BaseProcessingInfo):
|
||||
return {"audio": 1}
|
||||
|
||||
def get_data_parser(self) -> "KimiAudioMultiModalDataParser":
|
||||
"""Get data parser for audio inputs."""
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
return KimiAudioMultiModalDataParser(
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
|
||||
|
||||
class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo]):
|
||||
"""Dummy inputs builder for vLLM registry."""
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
|
||||
"""Return dummy text as token IDs directly."""
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
if num_audios == 0:
|
||||
return [198] # "Transcribe" tokenized
|
||||
# Return as token IDs directly to avoid tokenizer issues
|
||||
return [
|
||||
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
|
||||
KimiAudioProcessor.KIMIA_TEXT_BLANK,
|
||||
KimiAudioProcessor.KIMIA_MEDIA_END,
|
||||
] * num_audios
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
return ""
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
@@ -186,6 +180,29 @@ class KimiAudioDummyInputsBuilder(BaseDummyInputsBuilder[KimiAudioProcessingInfo
|
||||
),
|
||||
}
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
) -> ProcessorInputs:
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
dummy_tokens = (
|
||||
[198]
|
||||
if num_audios == 0
|
||||
else [
|
||||
KimiAudioProcessor.KIMIA_MEDIA_BEGIN,
|
||||
KimiAudioProcessor.KIMIA_TEXT_BLANK,
|
||||
KimiAudioProcessor.KIMIA_MEDIA_END,
|
||||
]
|
||||
* num_audios
|
||||
)
|
||||
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
|
||||
|
||||
|
||||
# Field config for Kimi-Audio multimodal data
|
||||
_KIMIAUDIO_FIELD_CONFIG = {
|
||||
@@ -197,10 +214,6 @@ _KIMIAUDIO_FIELD_CONFIG = {
|
||||
class KimiAudioMultiModalDataParser(MultiModalDataParser):
|
||||
"""Custom data parser for Kimi-Audio multimodal data."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Whisper expects 16kHz audio
|
||||
super().__init__(target_sr=16000, **kwargs)
|
||||
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
|
||||
@@ -589,9 +602,8 @@ class KimiAudioForConditionalGeneration(
|
||||
loaded = loader.load_weights(main_weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
# Load Whisper encoder weights from subfolder
|
||||
whisper_path = os.path.join(
|
||||
self.model_path, f"{KIMIA_WHISPER_SUBFOLDER}/model.safetensors"
|
||||
)
|
||||
whisper_dir = _get_whisper_local_path(self.model_path)
|
||||
whisper_path = os.path.join(whisper_dir, "model.safetensors")
|
||||
if os.path.exists(whisper_path):
|
||||
whisper_loaded = self._load_whisper_weights_from_file(whisper_path)
|
||||
loaded.update(whisper_loaded)
|
||||
|
||||
@@ -63,12 +63,10 @@ from vllm.multimodal.processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
InputProcessingContext,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
@@ -546,9 +544,6 @@ class Llama4VisionModel(nn.Module):
|
||||
|
||||
|
||||
class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
def __init__(self, ctx: InputProcessingContext) -> None:
|
||||
super().__init__(ctx)
|
||||
|
||||
def get_hf_config(self) -> Llama4Config:
|
||||
return self.ctx.get_hf_config(Llama4Config)
|
||||
|
||||
@@ -557,9 +552,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
|
||||
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
|
||||
)
|
||||
|
||||
def get_default_tok_params(self) -> TokenizeParams:
|
||||
return super().get_default_tok_params().with_kwargs(add_special_tokens=False)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
# Although vLLM can support more images from an infra capability
|
||||
# perspective, we do not recommend using >10 images in practice.
|
||||
@@ -597,10 +589,6 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo])
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
if mm_data is None:
|
||||
return tokenizer(prompt, add_special_tokens=False) # exclude bos
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
|
||||
@@ -172,12 +172,20 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
mm_data: MultiModalDataDict | None = None,
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
dummy_images = dummy_mm_data.get("image", [])
|
||||
dummy_mm_data = (
|
||||
self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
if mm_data is None
|
||||
else mm_data
|
||||
)
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
dummy_images = (
|
||||
[] if "image" not in dummy_mm_data else dummy_mm_items["image"].get_all()
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(
|
||||
messages=[
|
||||
@@ -192,8 +200,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
dummy_tokens = res.tokens
|
||||
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
|
||||
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
|
||||
|
||||
|
||||
|
||||
@@ -150,13 +150,21 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
mm_data: MultiModalDataDict | None = None,
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
feature_extractor = self.info.get_hf_processor().feature_extractor
|
||||
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
dummy_audios = dummy_mm_data.get("audio", [])
|
||||
dummy_mm_data = (
|
||||
self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
if mm_data is None
|
||||
else mm_data
|
||||
)
|
||||
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
|
||||
dummy_audios = (
|
||||
[] if "audio" not in dummy_mm_data else dummy_mm_items["audio"].get_all()
|
||||
)
|
||||
|
||||
audio_chunks: list[AudioChunk] = []
|
||||
format = "wav"
|
||||
|
||||
Reference in New Issue
Block a user