[Model Runner V2] Add WhisperModelState [6/N] (#35790)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@@ -47,8 +47,7 @@ steps:
- python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
# TODO: uncomment once https://github.com/vllm-project/vllm/pull/35790 is merged.
#- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 # TODO
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
# for pooling models
- python3 pooling/embed/vision_embedding_offline.py --seed 0
# for features demo

@@ -3,6 +3,7 @@
from collections.abc import Sequence
from typing import Any, cast

import numpy as np
import torch

from vllm.config import VllmConfig, get_layers_from_vllm_config
@@ -180,6 +181,7 @@ def build_attn_metadata(
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
dcp_local_seq_lens: torch.Tensor | None = None,
encoder_seq_lens: dict[int, tuple[torch.Tensor, np.ndarray]] | None = None,
) -> dict[str, Any]:
seq_lens = seq_lens[:num_reqs]
if dcp_local_seq_lens is not None:
@@ -204,6 +206,10 @@ def build_attn_metadata(
causal=True,
dcp_local_seq_lens=dcp_local_seq_lens,
)
if encoder_seq_lens and i in encoder_seq_lens:
encoder_seq_lens_gpu, encoder_seq_lens_cpu = encoder_seq_lens[i]
common_attn_metadata.encoder_seq_lens = encoder_seq_lens_gpu
common_attn_metadata.encoder_seq_lens_cpu = encoder_seq_lens_cpu

for attn_group in attn_groups[i]:
attn_metadata_builder = attn_group.get_metadata_builder(0)

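The new encoder_seq_lens argument maps a KV-cache group index to a pair of per-request encoder lengths (a GPU tensor plus the matching CPU numpy array). Only groups that contain cross-attention layers get an entry; for those, build_attn_metadata copies the pair onto the group's common metadata before the per-group builders run. A minimal sketch of what a caller would pass (the group index and lengths below are invented for illustration):

import numpy as np
import torch

# Hypothetical example: two requests whose audio produces 1500 and 750 encoder
# positions. The dict is keyed by the index of the KV-cache group that holds
# the cross-attention layers (index 1 here is made up for illustration).
device = "cuda" if torch.cuda.is_available() else "cpu"
encoder_seq_lens_np = np.array([1500, 750], dtype=np.int32)
encoder_seq_lens = {
    1: (torch.from_numpy(encoder_seq_lens_np).to(device), encoder_seq_lens_np),
}
# Groups without an entry (e.g., the decoder self-attention group) keep their
# causal metadata unchanged and never see encoder lengths.
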
@@ -389,5 +389,6 @@ def prepare_inputs_to_capture(
slot_mappings,
attn_groups,
kv_cache_config,
for_capture=True,
)
return attn_metadata, slot_mappings_by_layer

@@ -125,6 +125,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.is_encoder_decoder = self.model_config.is_encoder_decoder

self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
@@ -159,12 +160,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.supports_mm_inputs and self.is_first_pp_rank:
self.encoder_cache = EncoderCache()

# Speculative decoding.
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
use_strict_rejection_sampling = False
if self.speculative_config is not None:
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)

if self.is_last_pp_rank:
self.speculator = init_speculator(self.vllm_config, self.device)

@@ -173,13 +179,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)

# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)

# General request states.
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
@@ -243,7 +247,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks: list[SupportedTask] = []
if self.model_config.runner_type == "generate":
tasks.append("generate")
tasks.extend(self.model_state.get_supported_generation_tasks())
if self.pooling_runner is not None:
tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
return tuple(tasks)
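Replacing the hard-coded "generate" with a call into the model state lets the advertised tasks follow the model type: the default state reports ("generate",) while the Whisper state added later in this diff reports ("transcription",). A stripped-down sketch of that dispatch (stub classes, no vLLM config):

class DefaultModelState:
    def get_supported_generation_tasks(self) -> tuple[str, ...]:
        return ("generate",)


class WhisperModelState:
    def get_supported_generation_tasks(self) -> tuple[str, ...]:
        return ("transcription",)


def get_supported_tasks(model_state, runner_type: str = "generate") -> tuple[str, ...]:
    tasks: list[str] = []
    if runner_type == "generate":
        tasks.extend(model_state.get_supported_generation_tasks())
    return tuple(tasks)


assert get_supported_tasks(DefaultModelState()) == ("generate",)
assert get_supported_tasks(WhisperModelState()) == ("transcription",)
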
@@ -307,11 +311,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for kv_cache_group in kv_cache_config.kv_cache_groups
]

block_table_max_model_len = self.max_model_len
if self.is_encoder_decoder:
# Cross-attention block tables need to index encoder tokens
# (e.g., Whisper ~1500), which can exceed decoder max_model_len.
block_table_max_model_len = max(
block_table_max_model_len,
getattr(self.model_config.hf_config, "max_source_positions", 0),
)

self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
max_model_len=block_table_max_model_len,
device=self.device,
cp_size=self.dcp_size,
cp_rank=self.dcp_rank,
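The max() is easiest to see with concrete numbers: Whisper's decoder context (max_model_len) is 448 tokens, while its encoder emits up to max_source_positions = 1500 positions, so the cross-attention block table must be sized for 1500 slots; decoder-only models are unaffected. A small sketch with those numbers (the helper name is illustrative, not part of the diff):

def block_table_len(max_model_len: int, is_encoder_decoder: bool,
                    max_source_positions: int = 0) -> int:
    # Hypothetical helper restating the inline logic above: encoder-decoder
    # models may need to index more (encoder) tokens than the decoder's own
    # context length allows.
    if is_encoder_decoder:
        return max(max_model_len, max_source_positions)
    return max_model_len


# Whisper: 448 decoder positions vs. 1500 encoder positions.
assert block_table_len(448, True, max_source_positions=1500) == 1500
# Decoder-only model: size is unchanged.
assert block_table_len(8192, False) == 8192
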
@@ -870,6 +883,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
num_tokens_across_dp = None

skip_compiled = False
if self.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
# Encoder-decoder models such as Whisper should run eager/non-compiled
# when encoder inputs are scheduled, because this step updates the
# cross-attention cache with dynamic encoder outputs.
# Override batch_desc to NONE.
skip_compiled = True
batch_desc = BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE,
num_tokens=num_toks,
num_reqs=num_reqs,
)

if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
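The override can be read as a small predicate: whenever an encoder-decoder model has encoder inputs scheduled, CUDA graphs and the compiled path are bypassed for that step and the batch descriptor is forced to CUDAGraphMode.NONE. A hedged sketch of just the decision (the helper name and sample request IDs are made up):

def should_skip_compiled(is_encoder_decoder: bool,
                         scheduled_encoder_inputs: dict[str, list[int]]) -> bool:
    # Hypothetical restatement of the branch above: steps that run the audio
    # encoder must execute eagerly so the freshly computed, variable-shape
    # encoder outputs can feed cross-attention and populate the KV cache.
    return bool(is_encoder_decoder and scheduled_encoder_inputs)


assert should_skip_compiled(True, {"req-0": [0]}) is True   # Whisper prefill step
assert should_skip_compiled(True, {}) is False              # pure decode step
assert should_skip_compiled(False, {}) is False             # decoder-only model
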
@@ -984,6 +1010,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings_by_layer,
skip_compiled=skip_compiled,
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)

@@ -13,6 +13,11 @@ def init_model_state(
encoder_cache: EncoderCache | None,
device: torch.device,
):
if "WhisperForConditionalGeneration" in vllm_config.model_config.architectures:
from vllm.v1.worker.gpu.model_states.whisper import WhisperModelState

return WhisperModelState(vllm_config, model, encoder_cache, device)

from vllm.v1.worker.gpu.model_states.default import DefaultModelState

return DefaultModelState(vllm_config, model, encoder_cache, device)

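The factory keys off the architecture string, so only Whisper gets the specialized state and every other model keeps the existing default. A standalone sketch of the dispatch (stub classes; the Llama architecture string is just an example):

class DefaultModelState:
    ...


class WhisperModelState:
    ...


def pick_model_state(architectures: list[str]):
    # Sketch of the same dispatch: Whisper is currently the only architecture
    # with a specialized state; everything else keeps the default behavior.
    if "WhisperForConditionalGeneration" in architectures:
        return WhisperModelState
    return DefaultModelState


assert pick_model_state(["WhisperForConditionalGeneration"]) is WhisperModelState
assert pick_model_state(["LlamaForCausalLM"]) is DefaultModelState
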
@@ -109,7 +109,7 @@ class DefaultModelState(ModelState):

def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
) -> dict[str, Any]:
if not self.uses_mrope:
# Common case (1D positions).
return {}
@@ -126,9 +126,7 @@ class DefaultModelState(ModelState):
]
return {"positions": mrope_positions}

def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
model_inputs = {}
if self.supports_mm_inputs:
inputs_embeds = self.encoder_runner.inputs_embeds[:num_tokens]
@@ -146,6 +144,7 @@ class DefaultModelState(ModelState):
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
for_capture: bool = False,
) -> dict[str, Any]:
if cudagraph_mode == CUDAGraphMode.FULL:
# Use padded sizes - padding is handled by model_runner.prepare_attn.

@@ -8,6 +8,7 @@ import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.tasks import GenerationTask
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch
@@ -27,13 +28,14 @@ class ModelState(ABC):
) -> None:
raise NotImplementedError

@abstractmethod
def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
raise NotImplementedError
def get_supported_generation_tasks(self) -> tuple[GenerationTask, ...]:
return ("generate",)

def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
return None

@abstractmethod
def apply_staged_writes(self) -> None:
raise NotImplementedError
return None

@abstractmethod
def get_mm_embeddings(
@@ -41,19 +43,17 @@ class ModelState(ABC):
scheduled_encoder_inputs: dict[str, list[int]],
input_batch: InputBatch,
req_states: RequestState,
) -> torch.Tensor:
) -> torch.Tensor | None:
raise NotImplementedError

@abstractmethod
def prepare_inputs(
self, input_batch: InputBatch, req_states: RequestState
) -> dict[str, torch.Tensor | None]:
) -> dict[str, Any]:
raise NotImplementedError

@abstractmethod
def prepare_dummy_inputs(
self, num_reqs: int, num_tokens: int
) -> dict[str, torch.Tensor | None]:
def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
raise NotImplementedError

@abstractmethod
@@ -65,5 +65,6 @@ class ModelState(ABC):
slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
for_capture: bool = False,
) -> dict[str, Any]:
raise NotImplementedError

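Net effect on the interface: add_request and apply_staged_writes lose @abstractmethod and become no-op hooks, get_supported_generation_tasks gains a ("generate",) default, and the remaining methods stay abstract with looser return types (dict[str, Any], torch.Tensor | None). A generic sketch of that ABC-with-default-hooks pattern (names simplified, not the real class):

from abc import ABC, abstractmethod


class StateBase(ABC):
    # Standalone illustration of the pattern above (generic names, not vLLM's):
    # optional hooks ship a default body, while core methods stay abstract.
    def get_supported_generation_tasks(self) -> tuple[str, ...]:
        return ("generate",)        # default; a Whisper-style state overrides this

    def add_request(self, req_index: int, new_req_data) -> None:
        return None                 # optional hook, no-op unless overridden

    @abstractmethod
    def prepare_inputs(self, input_batch) -> dict:
        raise NotImplementedError   # every concrete state must implement this
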
vllm/v1/worker/gpu/model_states/whisper.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup


class WhisperModelState(ModelState):
    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.scheduler_config = vllm_config.scheduler_config
        self.model = model
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_model_len = self.model_config.max_model_len
        self.device = device

        assert encoder_cache is not None
        self.encoder_cache = encoder_cache
        self.encoder_runner = EncoderRunner(
            model=self.model,
            max_num_tokens=self.max_num_tokens,
            hidden_size=self.model_config.get_inputs_embeds_size(),
            encoder_cache=self.encoder_cache,
            dtype=self.model_config.dtype,
            device=self.device,
        )

        self.max_encoder_len = getattr(
            self.model_config.hf_config,
            "max_source_positions",
            self.max_model_len,
        )
        self.encoder_seq_lens_gpu = torch.zeros(
            self.max_num_reqs, dtype=torch.int32, device=self.device
        )

        self.encoder_outputs: list[torch.Tensor] = []

    def get_supported_generation_tasks(self):
        return ("transcription",)

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> None:
        # Ensure encoder inputs are ordered consistently with input_batch.req_ids.
        encoder_inputs: dict[str, list[int]] = {}
        for req_id in input_batch.req_ids:
            req_encoder_inputs = scheduled_encoder_inputs.get(req_id, [])
            if req_encoder_inputs:
                encoder_inputs[req_id] = req_encoder_inputs
        _, mm_kwargs = self.encoder_runner.prepare_mm_inputs(encoder_inputs)
        if mm_kwargs:
            # Whisper consumes encoder outputs through `encoder_outputs`, not
            # `inputs_embeds`. Single modality (audio) so execute_mm_encoder
            # preserves request order; use its return value directly.
            # No need to store in encoder_cache: cross-attention K/V are written
            # to the KV cache on the first step; decode steps use the cache.
            self.encoder_outputs = self.encoder_runner.execute_mm_encoder(mm_kwargs)
        else:
            # Decode steps: encoder K/V are in cross-attention KV cache.
            self.encoder_outputs = []
        return None

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, Any]:
        model_inputs = {"encoder_outputs": self.encoder_outputs}
        self.encoder_outputs = []
        return model_inputs

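get_mm_embeddings and prepare_inputs implement a hand-off: on a step that runs the audio encoder, the outputs are stashed on the state and forwarded to the model exactly once as the encoder_outputs kwarg; on pure decode steps the list stays empty and the decoder reads encoder K/V from the cross-attention KV cache. A toy sketch of that hand-off (the runner, tensor shapes, and class names here are fabricated; only the control flow mirrors the methods above):

import torch


class _FakeEncoderRunner:
    # Stand-in for EncoderRunner: one output tensor per scheduled audio input.
    def run(self, mm_kwargs):
        return [torch.zeros(1500, 1280) for _ in mm_kwargs]  # (encoder_len, hidden)


class _WhisperHandoff:
    def __init__(self) -> None:
        self.encoder_runner = _FakeEncoderRunner()
        self.encoder_outputs: list[torch.Tensor] = []

    def get_mm_embeddings(self, mm_kwargs) -> None:
        # Step with scheduled audio: compute and stash encoder outputs.
        # Pure decode step: leave the list empty; K/V already sit in the
        # cross-attention KV cache.
        self.encoder_outputs = self.encoder_runner.run(mm_kwargs) if mm_kwargs else []

    def prepare_inputs(self) -> dict:
        model_inputs = {"encoder_outputs": self.encoder_outputs}
        self.encoder_outputs = []  # consumed once, never re-fed on later steps
        return model_inputs


h = _WhisperHandoff()
h.get_mm_embeddings(["audio-0"])
assert len(h.prepare_inputs()["encoder_outputs"]) == 1  # encoder ran this step
assert h.prepare_inputs()["encoder_outputs"] == []      # next (decode) step
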
    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        return {"encoder_outputs": []}

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        if cudagraph_mode == CUDAGraphMode.FULL:
            num_reqs = input_batch.num_reqs_after_padding
            num_tokens = input_batch.num_tokens_after_padding
        else:
            num_reqs = input_batch.num_reqs
            num_tokens = input_batch.num_tokens
        encoder_seq_lens = self._get_encoder_seq_lens(
            input_batch.req_ids, attn_groups, for_capture
        )

        query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
        max_query_len = input_batch.num_scheduled_tokens.max().item()
        attn_metadata = build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            max_query_len=max_query_len,
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
            encoder_seq_lens=encoder_seq_lens,
        )
        return attn_metadata

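prepare_attn follows the default implementation but threads the per-request encoder lengths into build_attn_metadata, and it has to pick between padded and real batch sizes depending on the CUDA-graph mode. A tiny sketch of that size selection (the helper name and numbers are illustrative):

def _batch_sizes(num_reqs: int, num_tokens: int,
                 padded_reqs: int, padded_tokens: int,
                 full_cudagraph: bool) -> tuple[int, int]:
    # Hypothetical helper restating the branch above: under FULL CUDA-graph
    # mode the attention metadata must describe the padded batch that the
    # captured graph will replay; otherwise the real sizes are used.
    if full_cudagraph:
        return padded_reqs, padded_tokens
    return num_reqs, num_tokens


assert _batch_sizes(3, 90, 4, 128, full_cudagraph=True) == (4, 128)
assert _batch_sizes(3, 90, 4, 128, full_cudagraph=False) == (3, 90)
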
    def _get_encoder_seq_lens(
        self,
        req_ids: list[str],
        attn_groups: list[list[AttentionGroup]],
        for_capture: bool,
    ) -> dict[int, tuple[torch.Tensor, np.ndarray]]:
        num_reqs = len(req_ids)
        encoder_seq_lens_np = np.zeros(num_reqs, dtype=np.int32)
        if not for_capture:
            # During normal execution, use actual encoder lengths.
            for i, req_id in enumerate(req_ids):
                mm_features = self.encoder_cache.mm_features.get(req_id, [])
                encoder_seq_lens_np[i] = sum(
                    feature.mm_position.get_num_embeds() for feature in mm_features
                )
        else:
            # During CUDA graph capture, use max encoder length so max_seqlen_k
            # is captured with the correct value for cross-attention.
            encoder_seq_lens_np[:] = self.max_encoder_len

        self.encoder_seq_lens_gpu[:num_reqs].copy_(
            torch.from_numpy(encoder_seq_lens_np), non_blocking=True
        )
        self.encoder_seq_lens_gpu[num_reqs:].fill_(0)
        encoder_seq_lens_gpu = self.encoder_seq_lens_gpu[:num_reqs]

        seq_lens_by_group: dict[int, tuple[torch.Tensor, np.ndarray]] = {}
        for kv_cache_group_idx, groups in enumerate(attn_groups):
            has_cross_attn = any(
                isinstance(attn_group.kv_cache_spec, CrossAttentionSpec)
                for attn_group in groups
            )
            if has_cross_attn:
                seq_lens_by_group[kv_cache_group_idx] = (
                    encoder_seq_lens_gpu,
                    encoder_seq_lens_np,
                )
        return seq_lens_by_group
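The helper builds one (GPU tensor, CPU array) pair of encoder lengths and attaches it only to KV-cache groups containing a CrossAttentionSpec; during graph capture every slot is filled with max_source_positions so the captured kernels see the largest possible key length. A small numeric sketch of the two modes (1500 matches Whisper's max_source_positions; the other lengths are invented):

import numpy as np

MAX_ENCODER_LEN = 1500  # Whisper's max_source_positions


def encoder_lens(actual_lens: list[int], for_capture: bool) -> np.ndarray:
    # Sketch of the two modes above: real lengths during execution, the
    # worst-case length during CUDA-graph capture so max_seqlen_k is captured
    # large enough for cross-attention.
    lens = np.zeros(len(actual_lens), dtype=np.int32)
    lens[:] = MAX_ENCODER_LEN if for_capture else actual_lens
    return lens


print(encoder_lens([1500, 750], for_capture=False))  # [1500  750]
print(encoder_lens([0, 0], for_capture=True))        # [1500 1500]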