diff --git a/.buildkite/test_areas/model_runner_v2.yaml b/.buildkite/test_areas/model_runner_v2.yaml index fa05e2247..e19b7297f 100644 --- a/.buildkite/test_areas/model_runner_v2.yaml +++ b/.buildkite/test_areas/model_runner_v2.yaml @@ -47,8 +47,7 @@ steps: - python3 offline_inference/audio_language.py --seed 0 - python3 offline_inference/vision_language.py --seed 0 - python3 offline_inference/vision_language_multi_image.py --seed 0 - # TODO: uncomment once https://github.com/vllm-project/vllm/pull/35790 is merged. - #- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 # TODO + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 # for pooling models - python3 pooling/embed/vision_embedding_offline.py --seed 0 # for features demo diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py index d9fc4515b..5354ef088 100644 --- a/vllm/v1/worker/gpu/attn_utils.py +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -3,6 +3,7 @@ from collections.abc import Sequence from typing import Any, cast +import numpy as np import torch from vllm.config import VllmConfig, get_layers_from_vllm_config @@ -180,6 +181,7 @@ def build_attn_metadata( slot_mappings: torch.Tensor, kv_cache_config: KVCacheConfig, dcp_local_seq_lens: torch.Tensor | None = None, + encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None, ) -> dict[str, Any]: seq_lens = seq_lens[:num_reqs] if dcp_local_seq_lens is not None: @@ -204,6 +206,10 @@ def build_attn_metadata( causal=True, dcp_local_seq_lens=dcp_local_seq_lens, ) + if encoder_seq_lens and i in encoder_seq_lens: + encoder_seq_lens_gpu, encoder_seq_lens_cpu = encoder_seq_lens[i] + common_attn_metadata.encoder_seq_lens = encoder_seq_lens_gpu + common_attn_metadata.encoder_seq_lens_cpu = encoder_seq_lens_cpu for attn_group in attn_groups[i]: attn_metadata_builder = attn_group.get_metadata_builder(0) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py 
b/vllm/v1/worker/gpu/cudagraph_utils.py index 202470c7b..3b44d580d 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -389,5 +389,6 @@ def prepare_inputs_to_capture( slot_mappings, attn_groups, kv_cache_config, + for_capture=True, ) return attn_metadata, slot_mappings_by_layer diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index ca2aacfc3..d751e83ba 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -125,6 +125,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.max_model_len = self.model_config.max_model_len self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.scheduler_config.max_num_seqs + self.is_encoder_decoder = self.model_config.is_encoder_decoder self.use_async_scheduling = self.scheduler_config.async_scheduling self.output_copy_stream = torch.cuda.Stream(self.device) @@ -159,12 +160,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.supports_mm_inputs and self.is_first_pp_rank: self.encoder_cache = EncoderCache() + # Speculative decoding. self.speculator = None self.num_speculative_steps = 0 self.use_aux_hidden_state_outputs = False use_strict_rejection_sampling = False if self.speculative_config is not None: self.num_speculative_steps = self.speculative_config.num_speculative_tokens + use_strict_rejection_sampling = ( + self.speculative_config.rejection_sample_method == "strict" + ) + if self.is_last_pp_rank: self.speculator = init_speculator(self.vllm_config, self.device) @@ -173,13 +179,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.use_aux_hidden_state_outputs = True if self.pp_size > 1: raise ValueError("EAGLE3 with pipeline parallel is not supported.") - use_strict_rejection_sampling = ( - self.speculative_config.rejection_sample_method == "strict" - ) # Draft tokens propagation - for spec-dec + struct outputs. 
self.draft_tokens_handler = DraftTokensHandler(self.device) + # General request states. self.req_states = RequestState( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -243,7 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks: list[SupportedTask] = [] if self.model_config.runner_type == "generate": - tasks.append("generate") + tasks.extend(self.model_state.get_supported_generation_tasks()) if self.pooling_runner is not None: tasks.extend(self.pooling_runner.get_supported_pooling_tasks()) return tuple(tasks) @@ -307,11 +311,20 @@ class GPUModelRunner(LoRAModelRunnerMixin): for kv_cache_group in kv_cache_config.kv_cache_groups ] + block_table_max_model_len = self.max_model_len + if self.is_encoder_decoder: + # Cross-attention block tables need to index encoder tokens + # (e.g., Whisper ~1500), which can exceed decoder max_model_len. + block_table_max_model_len = max( + block_table_max_model_len, + getattr(self.model_config.hf_config, "max_source_positions", 0), + ) + self.block_tables = BlockTables( block_sizes=block_sizes, max_num_reqs=self.max_num_reqs, max_num_batched_tokens=self.max_num_tokens, - max_model_len=self.max_model_len, + max_model_len=block_table_max_model_len, device=self.device, cp_size=self.dcp_size, cp_rank=self.dcp_rank, @@ -870,6 +883,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) num_tokens_across_dp = None + skip_compiled = False + if self.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs: + # Encoder-decoder models such as Whisper should run eager/non-compiled + # when encoder inputs are scheduled, because this step updates + # cross-attention cache with dynamic encoder outputs. + # Override batch_desc to NONE. 
+ skip_compiled = True + batch_desc = BatchExecutionDescriptor( + cg_mode=CUDAGraphMode.NONE, + num_tokens=num_toks, + num_reqs=num_reqs, + ) + if self.dp_size > 1: batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding( self.cudagraph_manager, @@ -984,6 +1010,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_tokens_across_dp=num_tokens_across_dp, batch_descriptor=batch_descriptor, slot_mapping=slot_mappings_by_layer, + skip_compiled=skip_compiled, ): self.kv_connector.pre_forward(scheduler_output) model_output = self.model(**model_inputs) diff --git a/vllm/v1/worker/gpu/model_states/__init__.py b/vllm/v1/worker/gpu/model_states/__init__.py index 3ddce0fdc..651452553 100644 --- a/vllm/v1/worker/gpu/model_states/__init__.py +++ b/vllm/v1/worker/gpu/model_states/__init__.py @@ -13,6 +13,11 @@ def init_model_state( encoder_cache: EncoderCache | None, device: torch.device, ): + if "WhisperForConditionalGeneration" in vllm_config.model_config.architectures: + from vllm.v1.worker.gpu.model_states.whisper import WhisperModelState + + return WhisperModelState(vllm_config, model, encoder_cache, device) + from vllm.v1.worker.gpu.model_states.default import DefaultModelState return DefaultModelState(vllm_config, model, encoder_cache, device) diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index 6d24c3663..783d225c4 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -109,7 +109,7 @@ class DefaultModelState(ModelState): def prepare_inputs( self, input_batch: InputBatch, req_states: RequestState - ) -> dict[str, torch.Tensor | None]: + ) -> dict[str, Any]: if not self.uses_mrope: # Common case (1D positions). 
return {} @@ -126,9 +126,7 @@ class DefaultModelState(ModelState): ] return {"positions": mrope_positions} - def prepare_dummy_inputs( - self, num_reqs: int, num_tokens: int - ) -> dict[str, torch.Tensor | None]: + def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]: model_inputs = {} if self.supports_mm_inputs: inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens] @@ -146,6 +144,7 @@ class DefaultModelState(ModelState): slot_mappings: torch.Tensor, attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, + for_capture: bool = False, ) -> dict[str, Any]: if cudagraph_mode == CUDAGraphMode.FULL: # Use padded sizes - padding is handled by model_runner.prepare_attn. diff --git a/vllm/v1/worker/gpu/model_states/interface.py b/vllm/v1/worker/gpu/model_states/interface.py index 064cfa195..1c114496d 100644 --- a/vllm/v1/worker/gpu/model_states/interface.py +++ b/vllm/v1/worker/gpu/model_states/interface.py @@ -8,6 +8,7 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode +from vllm.tasks import GenerationTask from vllm.v1.core.sched.output import NewRequestData from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.input_batch import InputBatch @@ -27,13 +28,14 @@ class ModelState(ABC): ) -> None: raise NotImplementedError - @abstractmethod - def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: - raise NotImplementedError + def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]: + return ("generate",) + + def add_request(self, req_index: int, new_req_data: NewRequestData) -> None: + return None - @abstractmethod def apply_staged_writes(self) -> None: - raise NotImplementedError + return None @abstractmethod def get_mm_embeddings( @@ -41,19 +43,17 @@ class ModelState(ABC): scheduled_encoder_inputs: dict[str, list[int]], input_batch: InputBatch, req_states: RequestState, - ) -> torch.Tensor: + ) -> 
class WhisperModelState(ModelState):
    """Model state for Whisper-style encoder-decoder models.

    Runs the multimodal (audio) encoder when encoder inputs are scheduled,
    hands the resulting tensors to the model through the ``encoder_outputs``
    keyword argument, and supplies per-request encoder sequence lengths to the
    attention metadata of every KV-cache group that contains cross-attention.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        """Cache config-derived limits and allocate the persistent buffers.

        Args:
            vllm_config: Global configuration; the model and scheduler
                sub-configs are read from it.
            model: The model whose encoder is driven by the encoder runner.
            encoder_cache: Encoder cache shared with the runner. Must not be
                ``None``.
            device: Device on which the encoder-length buffer is allocated.
        """
        model_config = vllm_config.model_config
        scheduler_config = vllm_config.scheduler_config

        self.vllm_config = vllm_config
        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.model = model
        self.device = device

        # Batch-size / length limits used to size buffers below.
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_model_len = model_config.max_model_len

        assert encoder_cache is not None
        self.encoder_cache = encoder_cache
        self.encoder_runner = EncoderRunner(
            model=model,
            max_num_tokens=self.max_num_tokens,
            hidden_size=model_config.get_inputs_embeds_size(),
            encoder_cache=encoder_cache,
            dtype=model_config.dtype,
            device=device,
        )

        # Encoder length used while capturing CUDA graphs. Falls back to the
        # decoder's max_model_len when the HF config has no
        # `max_source_positions` attribute.
        self.max_encoder_len = getattr(
            model_config.hf_config, "max_source_positions", self.max_model_len
        )

        # Persistent device buffer with one int32 encoder length per request
        # slot; views of it are handed to the attention metadata.
        self.encoder_seq_lens_gpu = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=device
        )

        # Encoder outputs produced by the latest `get_mm_embeddings` call and
        # not yet consumed by `prepare_inputs`.
        self.encoder_outputs: list[torch.Tensor] = []

    def get_supported_generation_tasks(self):
        """Return the generation tasks served by this model state."""
        return ("transcription",)

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> None:
        """Run the encoder for this step and stash its outputs.

        Nothing is returned: Whisper receives encoder results through
        ``encoder_outputs`` (see `prepare_inputs`) rather than through
        ``inputs_embeds``.

        Args:
            scheduled_encoder_inputs: Request id -> encoder input indices
                scheduled for this step.
            input_batch: Current batch; its ``req_ids`` define the ordering.
            req_states: Unused here.
        """
        # Re-key the scheduled inputs in batch order, dropping requests that
        # have nothing scheduled, so the encoder sees requests in the same
        # order as input_batch.req_ids.
        ordered_inputs: dict[str, list[int]] = {
            req_id: scheduled
            for req_id in input_batch.req_ids
            if (scheduled := scheduled_encoder_inputs.get(req_id, []))
        }
        _, mm_kwargs = self.encoder_runner.prepare_mm_inputs(ordered_inputs)

        # With encoder work: there is a single modality (audio), so
        # execute_mm_encoder keeps request order and its return value is used
        # directly. The outputs are not stored in the encoder cache because
        # cross-attention K/V are written to the KV cache on the first step.
        # Without encoder work (decode steps): the encoder K/V already live in
        # the cross-attention KV cache, so there is nothing to pass along.
        self.encoder_outputs = (
            self.encoder_runner.execute_mm_encoder(mm_kwargs) if mm_kwargs else []
        )
        return None

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, Any]:
        """Return the pending encoder outputs as model kwargs and clear them.

        The stored list is handed over exactly once; a subsequent call without
        an intervening `get_mm_embeddings` yields an empty list.
        """
        pending, self.encoder_outputs = self.encoder_outputs, []
        return {"encoder_outputs": pending}

    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        """Return model kwargs for a dummy run: no encoder outputs."""
        return {"encoder_outputs": []}

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        """Build attention metadata, including encoder lengths for cross-attn.

        Args:
            input_batch: Current batch.
            cudagraph_mode: When ``FULL``, the padded request/token counts of
                the batch are used; otherwise the unpadded ones.
            block_tables: Forwarded to `build_attn_metadata`.
            slot_mappings: Forwarded to `build_attn_metadata`.
            attn_groups: Attention groups, indexed by KV-cache group.
            kv_cache_config: Forwarded to `build_attn_metadata`.
            for_capture: ``True`` during CUDA graph capture; encoder lengths
                are then set to the maximum encoder length.

        Returns:
            Whatever `build_attn_metadata` returns.
        """
        use_padded_sizes = cudagraph_mode == CUDAGraphMode.FULL
        num_reqs = (
            input_batch.num_reqs_after_padding
            if use_padded_sizes
            else input_batch.num_reqs
        )
        num_tokens = (
            input_batch.num_tokens_after_padding
            if use_padded_sizes
            else input_batch.num_tokens
        )

        cross_attn_seq_lens = self._get_encoder_seq_lens(
            input_batch.req_ids, attn_groups, for_capture
        )
        return build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            encoder_seq_lens=cross_attn_seq_lens,
        )

    def _get_encoder_seq_lens(
        self,
        req_ids: list[str],
        attn_groups: list[list[AttentionGroup]],
        for_capture: bool,
    ) -> dict[int, tuple[torch.Tensor, np.ndarray]]:
        """Compute per-request encoder lengths for cross-attention groups.

        Args:
            req_ids: Request ids of the batch, in batch order.
            attn_groups: Attention groups, indexed by KV-cache group.
            for_capture: If ``True``, every request gets ``max_encoder_len``
                instead of its actual encoder length.

        Returns:
            A mapping from KV-cache group index to ``(gpu_lens, cpu_lens)``,
            containing only the groups that have at least one attention group
            whose KV-cache spec is a ``CrossAttentionSpec``.
        """
        batch_size = len(req_ids)
        lens_cpu = np.zeros(batch_size, dtype=np.int32)
        if for_capture:
            # During CUDA graph capture, use the max encoder length so that
            # max_seqlen_k is captured with the correct value for
            # cross-attention.
            lens_cpu[:] = self.max_encoder_len
        else:
            # Normal execution: a request's encoder length is the total number
            # of embeddings over its multimodal features (0 if it has none
            # registered in the encoder cache).
            for idx, req_id in enumerate(req_ids):
                features = self.encoder_cache.mm_features.get(req_id, [])
                lens_cpu[idx] = sum(
                    feat.mm_position.get_num_embeds() for feat in features
                )

        # Refresh the persistent device buffer: live entries first, then zero
        # the unused tail.
        # NOTE(review): the tail reset presumably keeps stale lengths from an
        # earlier, larger batch from being observed through the persistent
        # buffer (e.g. by a replayed CUDA graph) -- confirm.
        lens_buffer = self.encoder_seq_lens_gpu
        lens_buffer[:batch_size].copy_(torch.from_numpy(lens_cpu), non_blocking=True)
        lens_buffer[batch_size:].fill_(0)
        lens_gpu = lens_buffer[:batch_size]

        return {
            group_idx: (lens_gpu, lens_cpu)
            for group_idx, groups in enumerate(attn_groups)
            if any(
                isinstance(group.kv_cache_spec, CrossAttentionSpec)
                for group in groups
            )
        }