[Core] Whisper support torch.compile (#30385)
Signed-off-by: NickLucche <nlucches@redhat.com>
@@ -156,7 +156,9 @@ def test_wer_correctness(
     model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
 ):
     # TODO refactor to use `ASRDataset`
-    with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
+    with RemoteOpenAIServer(
+        model_name, ["--enforce-eager"], max_wait_seconds=480
+    ) as remote_server:
         dataset = load_hf_dataset(dataset_repo)
 
         if not max_concurrent_request:
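For context, a minimal sketch (not part of this commit) of how the WER test's transcription requests reach the spun-up server, assuming the standard `openai` Python client, the server's default local base URL, and a placeholder model name and audio file:

# Hedged sketch: exercise the /v1/audio/transcriptions endpoint of a running
# vLLM OpenAI-compatible server. Base URL, model name, and audio path are
# placeholders, not values taken from the test.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
with open("sample.wav", "rb") as audio_file:
    transcription = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",
        file=audio_file,
        language="en",
        response_format="text",
    )
print(transcription)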
@@ -25,6 +25,7 @@ from vllm.config import (
     set_current_vllm_config,
 )
 from vllm.config.compilation import DynamicShapesType
+from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -388,6 +389,12 @@ def _support_torch_compile(
         if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
+        # If skip_compiled is set, bypass compiled model call. This is used e.g. for
+        # enc-dec models where tensor shapes/types vary across invocations, preventing
+        # the capture of a single computational graph.
+        if is_forward_context_available() and get_forward_context().skip_compiled:
+            return self.forward(*args, **kwargs)
+
         # if aot_compiled_fn is set, call it with partition wrapper context.
         # The partition wrapper must be active at runtime for CUDA graph
         # capture to work correctly with inductor graph partitioning.
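The bypass above follows a general pattern: keep a compiled entry point, but fall back to the plain eager forward() whenever a runtime flag says the current call is not safe to trace. A minimal self-contained sketch of that pattern with vanilla torch.compile (the names TinyDecoder and run are illustrative, not the decorator's internals):

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

model = TinyDecoder()
compiled_model = torch.compile(model)

def run(x: torch.Tensor, skip_compiled: bool) -> torch.Tensor:
    # Same idea as the decorator's early return: when the flag is set,
    # call the eager forward() directly instead of the compiled wrapper.
    if skip_compiled:
        return model.forward(x)
    return compiled_model(x)

_ = run(torch.randn(2, 8), skip_compiled=True)   # eager, no graph capture
_ = run(torch.randn(2, 8), skip_compiled=False)  # traced/compiled path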
@@ -207,6 +207,9 @@ class ForwardContext:
 
     ubatch_slices: UBatchSlices | None = None
 
+    # If True, bypass the compiled model call, e.g. by using .forward() directly
+    skip_compiled: bool = False
+
     additional_kwargs: dict[str, Any] = field(default_factory=dict)
 
     def __post_init__(self):
@@ -240,6 +243,7 @@ def create_forward_context(
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
     additional_kwargs: dict[str, Any] | None = None,
+    skip_compiled: bool = False,
 ):
     return ForwardContext(
         no_compile_layers=vllm_config.compilation_config.static_forward_context,
@@ -249,6 +253,7 @@ def create_forward_context(
         cudagraph_runtime_mode=cudagraph_runtime_mode,
         batch_descriptor=batch_descriptor,
         ubatch_slices=ubatch_slices,
+        skip_compiled=skip_compiled,
         additional_kwargs=additional_kwargs or {},
     )
 
@@ -278,6 +283,7 @@ def set_forward_context(
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
     batch_descriptor: BatchDescriptor | None = None,
     ubatch_slices: UBatchSlices | None = None,
+    skip_compiled: bool = False,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -336,6 +342,7 @@ def set_forward_context(
         batch_descriptor,
         ubatch_slices,
         additional_kwargs,
+        skip_compiled,
     )
 
     try:
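The four hunks above plumb a single boolean through the forward-context plumbing: a field on ForwardContext, a parameter on create_forward_context, and a parameter on the set_forward_context context manager that forwards it along. A simplified illustration of that mechanism (toy names, not vLLM's implementation) showing how a per-pass flag becomes queryable from anywhere inside the pass:

# Simplified illustration (not vLLM's code) of a per-forward context object
# carrying a skip_compiled flag that downstream code can query.
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any

@dataclass
class ToyForwardContext:
    skip_compiled: bool = False
    additional_kwargs: dict[str, Any] = field(default_factory=dict)

_current_ctx: ToyForwardContext | None = None

def is_ctx_available() -> bool:
    return _current_ctx is not None

def get_ctx() -> ToyForwardContext:
    assert _current_ctx is not None, "no forward context set"
    return _current_ctx

@contextmanager
def set_ctx(skip_compiled: bool = False):
    global _current_ctx
    prev, _current_ctx = _current_ctx, ToyForwardContext(skip_compiled=skip_compiled)
    try:
        yield _current_ctx
    finally:
        _current_ctx = prev

# The runner sets the flag for one forward pass; the compiled wrapper
# checks it and falls back to eager when it is True.
with set_ctx(skip_compiled=True):
    assert is_ctx_available() and get_ctx().skip_compiled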
@@ -19,6 +19,7 @@ from transformers import (
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -561,6 +562,7 @@ class WhisperEncoder(nn.Module):
         return self.forward_layers(hidden_states)
 
 
+@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
 class WhisperDecoder(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
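The dynamic_arg_dims hint tells the compilation wrapper which input dimensions vary between calls (here dim 0 of input_ids and the last dim of positions), so one graph can serve every decode batch size. A rough stand-alone analogue of that hint using plain torch.compile and torch._dynamo.mark_dynamic (a sketch of the idea, not the decorator's internals; ToyDecoder is a placeholder module):

import torch
import torch.nn as nn

class ToyDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(1000, 32)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)

model = torch.compile(ToyDecoder())

for num_tokens in (3, 17, 64):
    input_ids = torch.randint(0, 1000, (num_tokens,))
    # Mark dim 0 as dynamic so a changing token count reuses the same
    # compiled graph instead of triggering a recompile.
    torch._dynamo.mark_dynamic(input_ids, 0)
    _ = model(input_ids)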
@@ -3268,6 +3268,13 @@ class GPUModelRunner(
             # Mark KV scales as calculated after the first forward pass
             self.calculate_kv_scales = False
 
+        # Encoder-decoder models can only compile the pure decode steps where no
+        # encoder inputs are present. Use eager for the first pass.
+        num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs)
+        has_encoder_input = (
+            self.model_config.is_encoder_decoder and num_encoder_reqs > 0
+        )
+
         # Run the model.
         # Use persistent buffers for CUDA graphs.
         with (
@@ -3279,6 +3286,7 @@ class GPUModelRunner(
                 cudagraph_runtime_mode=cudagraph_mode,
                 batch_descriptor=batch_desc,
                 ubatch_slices=ubatch_slices_padded,
+                skip_compiled=has_encoder_input,
             ),
            record_function_or_nullcontext("gpu_model_runner: forward"),
            self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
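Putting the two runner hunks together: each step checks the scheduler output for encoder inputs and disables the compiled path only for that forward pass, so Whisper's prefill-plus-encoder step runs eagerly while subsequent decode-only steps use the compiled decoder. A condensed, self-contained sketch of that control flow (all names here are placeholders standing in for the vLLM calls shown above):

# Illustrative only: how the per-step flag gates the compiled path.
# set_ctx stands in for vLLM's set_forward_context; run_compiled/run_eager
# stand in for the compiled and eager model calls.
from contextlib import contextmanager

@contextmanager
def set_ctx(skip_compiled: bool = False):
    yield skip_compiled  # placeholder for a real forward-context object

def execute_step(run_compiled, run_eager, num_encoder_reqs: int,
                 is_encoder_decoder: bool):
    # Mirrors the runner hunk: go eager whenever encoder inputs are scheduled.
    has_encoder_input = is_encoder_decoder and num_encoder_reqs > 0
    with set_ctx(skip_compiled=has_encoder_input) as skip:
        return run_eager() if skip else run_compiled()

# First Whisper step carries encoder inputs -> eager; later decode steps -> compiled.
print(execute_step(lambda: "compiled", lambda: "eager", 1, True))  # eager
print(execute_step(lambda: "compiled", lambda: "eager", 0, True))  # compiled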