[Model Runner V2] Add WhisperModelState [6/N] (#35790)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-11 14:20:38 -07:00
committed by GitHub
parent c77181e534
commit 55eed6b7a5
8 changed files with 233 additions and 21 deletions

View File

@@ -47,8 +47,7 @@ steps:
- python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
# TODO: uncomment once https://github.com/vllm-project/vllm/pull/35790 is merged.
#- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 # TODO
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
# for pooling models
- python3 pooling/embed/vision_embedding_offline.py --seed 0
# for features demo

View File

@@ -3,6 +3,7 @@
from collections.abc import Sequence
from typing import Any, cast
import numpy as np
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -180,6 +181,7 @@ def build_attn_metadata(
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
@@ -204,6 +206,10 @@ def build_attn_metadata(
causal=True,
dcp_local_seq_lens=dcp_local_seq_lens,
)
if encoder_seq_lens and i in encoder_seq_lens:
encoder_seq_lens_gpu, encoder_seq_lens_cpu = encoder_seq_lens[i]
common_attn_metadata.encoder_seq_lens = encoder_seq_lens_gpu
common_attn_metadata.encoder_seq_lens_cpu = encoder_seq_lens_cpu
for attn_group in attn_groups[i]:
attn_metadata_builder = attn_group.get_metadata_builder(0)

View File

@@ -389,5 +389,6 @@ def prepare_inputs_to_capture(
slot_mappings,
attn_groups,
kv_cache_config,
for_capture=True,
)
return attn_metadata, slot_mappings_by_layer

View File

@@ -125,6 +125,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.is_encoder_decoder = self.model_config.is_encoder_decoder
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
@@ -159,12 +160,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.supports_mm_inputs and self.is_first_pp_rank:
self.encoder_cache = EncoderCache()
# Speculative decoding.
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
use_strict_rejection_sampling = False
if self.speculative_config is not None:
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)
if self.is_last_pp_rank:
self.speculator = init_speculator(self.vllm_config, self.device)
@@ -173,13 +179,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
# General request states.
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
@@ -243,7 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks: list[SupportedTask] = []
if self.model_config.runner_type == "generate":
tasks.append("generate")
tasks.extend(self.model_state.get_supported_generation_tasks())
if self.pooling_runner is not None:
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
return tuple(tasks)
@@ -307,11 +311,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for kv_cache_group in kv_cache_config.kv_cache_groups
]
block_table_max_model_len = self.max_model_len
if self.is_encoder_decoder:
# Cross-attention block tables need to index encoder tokens
# (e.g., Whisper ~1500), which can exceed decoder max_model_len.
block_table_max_model_len = max(
block_table_max_model_len,
getattr(self.model_config.hf_config, "max_source_positions", 0),
)
self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
max_model_len=block_table_max_model_len,
device=self.device,
cp_size=self.dcp_size,
cp_rank=self.dcp_rank,
@@ -870,6 +883,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
num_tokens_across_dp = None
skip_compiled = False
if self.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
# Encoder-decoder models such as Whisper should run eager/non-compiled
# when encoder inputs are scheduled, because this step updates
# cross-attention cache with dynamic encoder outputs.
# Override batch_desc to NONE.
skip_compiled = True
batch_desc = BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE,
num_tokens=num_toks,
num_reqs=num_reqs,
)
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
@@ -984,6 +1010,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings_by_layer,
skip_compiled=skip_compiled,
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)

View File

@@ -13,6 +13,11 @@ def init_model_state(
encoder_cache: EncoderCache | None,
device: torch.device,
):
if "WhisperForConditionalGeneration" in vllm_config.model_config.architectures:
from vllm.v1.worker.gpu.model_states.whisper import WhisperModelState
return WhisperModelState(vllm_config, model, encoder_cache, device)
from vllm.v1.worker.gpu.model_states.default import DefaultModelState
return DefaultModelState(vllm_config, model, encoder_cache, device)

View File

@@ -109,7 +109,7 @@ class DefaultModelState(ModelState):
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
) -> dict[str, Any]:
if not self.uses_mrope:
# Common case (1D positions).
return {}
@@ -126,9 +126,7 @@ class DefaultModelState(ModelState):
]
return {"positions": mrope_positions}
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
model_inputs = {}
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
@@ -146,6 +144,7 @@ class DefaultModelState(ModelState):
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
for_capture: bool = False,
) -> dict[str, Any]:
if cudagraph_mode == CUDAGraphMode.FULL:
# Use padded sizes - padding is handled by model_runner.prepare_attn.

View File

@@ -8,6 +8,7 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.tasks import GenerationTask
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
@@ -27,13 +28,14 @@ class ModelState(ABC):
) -> None:
raise NotImplementedError
@abstractmethod
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
raise NotImplementedError
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
return ("generate",)
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
return None
@abstractmethod
def apply_staged_writes(self) -> None:
raise NotImplementedError
return None
@abstractmethod
def get_mm_embeddings(
@@ -41,19 +43,17 @@ class ModelState(ABC):
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
) -> torch.Tensor | None:
raise NotImplementedError
@abstractmethod
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
) -> dict[str, Any]:
raise NotImplementedError
@abstractmethod
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
raise NotImplementedError
@abstractmethod
@@ -65,5 +65,6 @@ class ModelState(ABC):
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
for_capture: bool = False,
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class WhisperModelState(ModelState):
    """Per-model runner state for Whisper-style encoder-decoder models.

    Runs the audio encoder when encoder inputs are scheduled, hands the raw
    encoder outputs to the model via ``encoder_outputs`` (not
    ``inputs_embeds``), and supplies per-request encoder sequence lengths to
    the cross-attention metadata builders.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        self.vllm_config = vllm_config
        self.model = model
        self.device = device
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.model_config.max_model_len

        # Whisper always consumes multimodal (audio) inputs, so the cache
        # must be present.
        assert encoder_cache is not None
        self.encoder_cache = encoder_cache
        self.encoder_runner = EncoderRunner(
            model=self.model,
            max_num_tokens=self.max_num_tokens,
            hidden_size=self.model_config.get_inputs_embeds_size(),
            encoder_cache=self.encoder_cache,
            dtype=self.model_config.dtype,
            device=self.device,
        )
        # Encoder length (e.g. ~1500 for Whisper) may exceed the decoder's
        # max_model_len; fall back to max_model_len when the HF config does
        # not declare max_source_positions.
        self.max_encoder_len = getattr(
            self.model_config.hf_config,
            "max_source_positions",
            self.max_model_len,
        )
        # Persistent GPU buffer holding per-request encoder lengths; sliced
        # per step in _get_encoder_seq_lens.
        self.encoder_seq_lens_gpu = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=self.device
        )
        # Encoder outputs produced this step; consumed (and cleared) by
        # prepare_inputs.
        self.encoder_outputs: list[torch.Tensor] = []

    def get_supported_generation_tasks(self):
        """Whisper supports transcription rather than free-form generation."""
        return ("transcription",)

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> None:
        """Run the audio encoder for newly scheduled encoder inputs.

        Stores the encoder outputs on ``self.encoder_outputs`` (ordered
        consistently with ``input_batch.req_ids``) instead of returning
        embeddings; Whisper feeds them to the model as ``encoder_outputs``.
        """
        # Reorder the scheduled encoder inputs to follow input_batch.req_ids.
        encoder_inputs = {
            req_id: scheduled_encoder_inputs[req_id]
            for req_id in input_batch.req_ids
            if scheduled_encoder_inputs.get(req_id)
        }
        _, mm_kwargs = self.encoder_runner.prepare_mm_inputs(encoder_inputs)
        if not mm_kwargs:
            # Decode-only step: encoder K/V already live in the
            # cross-attention KV cache.
            self.encoder_outputs = []
            return None
        # Whisper consumes encoder outputs through `encoder_outputs`, not
        # `inputs_embeds`. With a single (audio) modality, execute_mm_encoder
        # preserves request order, so its return value is used directly.
        # No need to store in encoder_cache: cross-attention K/V are written
        # to the KV cache on the first step; decode steps use the cache.
        self.encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
        return None

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, Any]:
        """Hand the pending encoder outputs to the model and clear them."""
        pending, self.encoder_outputs = self.encoder_outputs, []
        return {"encoder_outputs": pending}

    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        """Dummy runs (profiling/capture) execute with no encoder outputs."""
        return {"encoder_outputs": []}

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        """Build attention metadata, including cross-attn encoder lengths."""
        if cudagraph_mode == CUDAGraphMode.FULL:
            # Full CUDA graphs replay fixed shapes; use the padded sizes.
            num_reqs = input_batch.num_reqs_after_padding
            num_tokens = input_batch.num_tokens_after_padding
        else:
            num_reqs = input_batch.num_reqs
            num_tokens = input_batch.num_tokens
        encoder_seq_lens = self._get_encoder_seq_lens(
            input_batch.req_ids, attn_groups, for_capture
        )
        return build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            encoder_seq_lens=encoder_seq_lens,
        )

    def _get_encoder_seq_lens(
        self,
        req_ids: list[str],
        attn_groups: list[list[AttentionGroup]],
        for_capture: bool,
    ) -> dict[int, tuple[torch.Tensor, np.ndarray]]:
        """Compute per-request encoder lengths for cross-attention groups.

        Returns a mapping from KV-cache-group index to a
        (GPU tensor, numpy array) pair of encoder sequence lengths, populated
        only for groups containing a cross-attention spec.
        """
        num_reqs = len(req_ids)
        encoder_seq_lens_np = np.zeros(num_reqs, dtype=np.int32)
        if for_capture:
            # During CUDA graph capture, use the max encoder length so
            # max_seqlen_k is captured with the correct value for
            # cross-attention.
            encoder_seq_lens_np[:] = self.max_encoder_len
        else:
            # Normal execution: sum the embedding counts of each request's
            # cached multimodal features.
            for idx, req_id in enumerate(req_ids):
                features = self.encoder_cache.mm_features.get(req_id, [])
                encoder_seq_lens_np[idx] = sum(
                    feat.mm_position.get_num_embeds() for feat in features
                )
        # Stage the lengths into the persistent GPU buffer and zero the tail.
        self.encoder_seq_lens_gpu[:num_reqs].copy_(
            torch.from_numpy(encoder_seq_lens_np), non_blocking=True
        )
        self.encoder_seq_lens_gpu[num_reqs:].fill_(0)
        encoder_seq_lens_gpu = self.encoder_seq_lens_gpu[:num_reqs]

        seq_lens_by_group: dict[int, tuple[torch.Tensor, np.ndarray]] = {}
        for group_idx, groups in enumerate(attn_groups):
            if any(
                isinstance(group.kv_cache_spec, CrossAttentionSpec)
                for group in groups
            ):
                seq_lens_by_group[group_idx] = (
                    encoder_seq_lens_gpu,
                    encoder_seq_lens_np,
                )
        return seq_lens_by_group