[v1] Add Whisper model support (encoder-decoder) (#21088)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Russell Bryant
2025-09-10 16:53:35 -04:00
committed by GitHub
parent 4db4426404
commit 37e8182bfe
31 changed files with 429 additions and 92 deletions

View File

@@ -8,6 +8,7 @@ import enum
import hashlib
import inspect
import json
import os
import textwrap
import warnings
from collections.abc import Mapping
@@ -41,6 +42,7 @@ from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
from vllm.config.utils import ConfigType, config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@@ -3509,16 +3511,33 @@ class VllmConfig:
disable_chunked_prefill_reasons: list[str] = []
if self.model_config and self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
if self.model_config:
if self.model_config.pooler_config:
pooling_type = self.model_config.pooler_config.pooling_type
if pooling_type is None or pooling_type.lower() != "last":
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
elif self.model_config.is_encoder_decoder:
self.scheduler_config.max_num_encoder_input_tokens = \
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config)
logger.debug(
"Encoder-decoder model detected: setting "
"`max_num_encoder_input_tokens` to encoder length (%s)",
self.scheduler_config.max_num_encoder_input_tokens)
self.scheduler_config.disable_chunked_mm_input = True
disable_chunked_prefill_reasons.append(
"Only \"last\" pooling supports chunked "
"prefill and prefix caching; disabling both.")
elif not getattr(self.model_config.hf_config, "is_causal", True):
disable_chunked_prefill_reasons.append(
"Only models using causal attention supports chunked "
"prefill and prefix caching; disabling both.")
"Encoder-decoder models do not support chunked prefill nor"
" prefix caching; disabling both.")
if (self.model_config.architecture
== "WhisperForConditionalGeneration"
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
!= "spawn"):
logger.warning(
"Whisper is known to have issues with "
"forked workers. If startup is hanging, "
"try setting 'VLLM_WORKER_MULTIPROC_METHOD' "
"to 'spawn'.")
if disable_chunked_prefill_reasons:
for reason in disable_chunked_prefill_reasons: