[Frontend] Move warmup into Renderer (#36482)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -72,6 +72,7 @@ from vllm.logprobs import Logprob
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.renderers import ChatParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
@@ -171,44 +172,14 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.supports_code_interpreter = False
|
||||
self.python_tool = None
|
||||
|
||||
def warmup(self) -> None:
    """Warm up chat template processing to avoid first-request latency.

    Delegates to the renderer, which compiles the Jinja2 chat template and
    performs content-format detection — work that would otherwise happen on
    the first real request and inflate its latency. The renderer logs and
    swallows warmup failures, so this never blocks server startup.
    """
    # NOTE: synchronous on purpose — the renderer's warmup does no I/O,
    # so callers invoke this without awaiting.
    self.renderer.warmup(
        ChatParams(
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
            chat_template_kwargs=self.default_chat_template_kwargs,
        )
    )
|
||||
|
||||
async def render_chat_request(
|
||||
self,
|
||||
|
||||
@@ -114,9 +114,8 @@ async def init_generate_state(
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
# Warm up chat template processing to avoid first-request latency
|
||||
if state.openai_serving_chat is not None:
|
||||
await state.openai_serving_chat.warmup()
|
||||
state.openai_serving_chat.warmup()
|
||||
state.openai_serving_completion = (
|
||||
OpenAIServingCompletion(
|
||||
engine_client,
|
||||
|
||||
@@ -42,10 +42,7 @@ from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import FlatLogprobs, Logprob
|
||||
from vllm.model_executor.models import (
|
||||
SupportsTranscription,
|
||||
supports_transcription,
|
||||
)
|
||||
from vllm.model_executor.models import SupportsTranscription
|
||||
from vllm.multimodal.audio import split_audio
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
|
||||
@@ -132,121 +129,6 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
# Warm up audio preprocessing to avoid first-request latency
|
||||
self._warmup_audio_preprocessing()
|
||||
# Warm up input processor with dummy audio
|
||||
self._warmup_input_processor()
|
||||
|
||||
def _warmup_audio_preprocessing(self) -> None:
    """Warm up audio processing libraries to avoid first-request latency.

    The first call to librosa functions (load, get_duration, mel-spectrogram)
    triggers JIT compilation and library initialization which can take ~7s.
    This method warms up these operations during server initialization.
    Failures are logged and swallowed so startup never fails because of warmup.
    """
    # Skip warmup if librosa is not installed (optional dependency)
    if isinstance(librosa, PlaceholderModule):
        return

    # Skip warmup if model doesn't support transcription
    if not supports_transcription(self.model_cls):
        return

    # Model classes may opt out explicitly
    if getattr(self.model_cls, "skip_warmup_audio_preprocessing", False):
        return

    try:
        warmup_start = time.perf_counter()
        logger.info("Warming up audio preprocessing libraries...")

        # Create a minimal dummy audio (1 second of silence at target sample rate)
        dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

        # Warm up librosa by using librosa functions on the dummy data.
        # This initializes FFTW, numba JIT, and other audio processing libraries.
        _ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)

        # Warm up mel-spectrogram computation with model-specific parameters
        from vllm.transformers_utils.processor import cached_processor_from_config

        processor = cached_processor_from_config(self.model_config)
        feature_extractor = None
        if hasattr(processor, "feature_extractor"):
            feature_extractor = processor.feature_extractor
        elif hasattr(processor, "audio_processor"):
            # For models like GraniteSpeech that use audio_processor
            audio_proc = processor.audio_processor
            if hasattr(audio_proc, "feature_extractor"):
                feature_extractor = audio_proc.feature_extractor
            # If audio_processor doesn't have feature_extractor,
            # skip mel-spectrogram warmup for these models

        if feature_extractor is not None:
            _ = librosa.feature.melspectrogram(
                y=dummy_audio,
                sr=self.asr_config.sample_rate,
                n_mels=getattr(feature_extractor, "n_mels", 128),
                n_fft=getattr(feature_extractor, "n_fft", 400),
                hop_length=getattr(feature_extractor, "hop_length", 160),
            )

        warmup_elapsed = time.perf_counter() - warmup_start
        logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
    except Exception:
        # Don't fail initialization if warmup fails - log exception and continue.
        # BUGFIX: the message previously contained a dangling "%s" placeholder
        # with no argument, which made the logging call itself error out.
        logger.exception(
            "Audio preprocessing warmup failed (non-fatal). "
            "First request may experience higher latency."
        )
|
||||
|
||||
def _warmup_input_processor(self) -> None:
    """Warm up input processor with dummy audio to avoid first-request latency.

    The first call to renderer.render_cmpl() with multimodal audio
    triggers multimodal processing initialization which can take ~2.5s.
    This method processes a dummy audio request to warm up the pipeline.
    Failures are logged and swallowed so startup never fails because of warmup.
    """
    # Skip warmup if model doesn't support transcription
    if not supports_transcription(self.model_cls):
        return

    # Only warm up if model supports transcription methods
    if not hasattr(self.model_cls, "get_generation_prompt"):
        return

    try:
        warmup_start = time.perf_counter()
        logger.info("Warming up multimodal input processor...")

        # Create minimal dummy audio (1 second of silence)
        dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

        # Use the same method that _preprocess_speech_to_text uses
        # to create the prompt
        dummy_prompt = self.model_cls.get_generation_prompt(
            audio=dummy_audio,
            stt_config=self.asr_config,
            model_config=self.model_config,
            language="en",
            task_type=self.task_type,
            request_prompt="",
            to_language=None,
        )
        parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)

        # Process the dummy input through the input processor.
        # This will trigger all the multimodal processing initialization.
        _ = self.renderer.render_cmpl([parsed_prompt])

        warmup_elapsed = time.perf_counter() - warmup_start
        logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
    except Exception:
        # Don't fail initialization if warmup fails - log and continue.
        # BUGFIX: the message previously contained a dangling "%s" placeholder
        # with no argument, which made the logging call itself error out.
        logger.exception(
            "Input processor warmup failed (non-fatal). "
            "First request may experience higher latency."
        )
|
||||
|
||||
@cached_property
|
||||
def model_cls(self) -> type[SupportsTranscription]:
|
||||
from vllm.model_executor.model_loader import get_model_cls
|
||||
|
||||
@@ -158,6 +158,56 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
if self._mm_cache_stats is not None:
|
||||
self._mm_cache_stats.reset = True
|
||||
|
||||
def warmup(self, chat_params: ChatParams) -> None:
    """
    Warm up this renderer to avoid first-request latency.

    For chat requests:
    - Jinja2 template compilation

    For multi-modal requests:
    - Importing libraries such as librosa triggers JIT compilation.
    """
    # Phase 1: force chat-template compilation via a one-message dummy chat.
    try:
        logger.info("Warming up chat template processing...")
        t0 = time.perf_counter()

        dummy_conversation = [[{"role": "user", "content": "warmup"}]]
        self.render_chat(dummy_conversation, chat_params)

        logger.info(
            "Chat template warmup completed in %.3fs", time.perf_counter() - t0
        )
    except Exception:
        logger.exception("Chat template warmup failed")

    # Phase 2: exercise the multi-modal processor, if one is configured.
    processor = self.mm_processor
    if not processor:
        return

    from vllm.multimodal.processing import TimingContext

    model_config = self.model_config
    mm_config = model_config.get_multimodal_config()
    mm_limits = processor.info.allowed_mm_limits

    try:
        logger.info("Warming up multi-modal processing...")
        t0 = time.perf_counter()

        # One dummy item per allowed modality, sized to the model's context.
        dummy_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=model_config.max_model_len,
            mm_counts=dict.fromkeys(mm_limits, 1),
            mm_options=mm_config.limit_per_prompt,
        )
        _ = processor.apply(
            dummy_inputs,
            timing_ctx=TimingContext(enabled=False),
        )

        logger.info(
            "Multi-modal warmup completed in %.3fs", time.perf_counter() - t0
        )
    except Exception:
        logger.exception("Multi-modal warmup failed")
    finally:
        # Never let warmup artifacts linger in the multi-modal cache.
        self.clear_mm_cache()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
mm_processor_cache = self.mm_processor_cache
|
||||
if mm_processor_cache is not None:
|
||||
|
||||
Reference in New Issue
Block a user