[Core] Whisper support torch.compile (#30385)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -156,7 +156,9 @@ def test_wer_correctness(
|
|||||||
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
|
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
|
||||||
):
|
):
|
||||||
# TODO refactor to use `ASRDataset`
|
# TODO refactor to use `ASRDataset`
|
||||||
with RemoteOpenAIServer(model_name, ["--enforce-eager"]) as remote_server:
|
with RemoteOpenAIServer(
|
||||||
|
model_name, ["--enforce-eager"], max_wait_seconds=480
|
||||||
|
) as remote_server:
|
||||||
dataset = load_hf_dataset(dataset_repo)
|
dataset = load_hf_dataset(dataset_repo)
|
||||||
|
|
||||||
if not max_concurrent_request:
|
if not max_concurrent_request:
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from vllm.config import (
|
|||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.config.compilation import DynamicShapesType
|
from vllm.config.compilation import DynamicShapesType
|
||||||
|
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
@@ -388,6 +389,12 @@ def _support_torch_compile(
|
|||||||
if self.do_not_compile or torch.compiler.is_compiling():
|
if self.do_not_compile or torch.compiler.is_compiling():
|
||||||
return self.forward(*args, **kwargs)
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
|
||||||
|
# enc-dec models where tensor shapes/types vary across invocations, preventing
|
||||||
|
# the capture of a single computational graph.
|
||||||
|
if is_forward_context_available() and get_forward_context().skip_compiled:
|
||||||
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
# if aot_compiled_fn is set, call it with partition wrapper context.
|
# if aot_compiled_fn is set, call it with partition wrapper context.
|
||||||
# The partition wrapper must be active at runtime for CUDA graph
|
# The partition wrapper must be active at runtime for CUDA graph
|
||||||
# capture to work correctly with inductor graph partitioning.
|
# capture to work correctly with inductor graph partitioning.
|
||||||
|
|||||||
@@ -207,6 +207,9 @@ class ForwardContext:
|
|||||||
|
|
||||||
ubatch_slices: UBatchSlices | None = None
|
ubatch_slices: UBatchSlices | None = None
|
||||||
|
|
||||||
|
# If True, bypass the compiled model call, e.g. by using .forward() directly
|
||||||
|
skip_compiled: bool = False
|
||||||
|
|
||||||
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
additional_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -240,6 +243,7 @@ def create_forward_context(
|
|||||||
batch_descriptor: BatchDescriptor | None = None,
|
batch_descriptor: BatchDescriptor | None = None,
|
||||||
ubatch_slices: UBatchSlices | None = None,
|
ubatch_slices: UBatchSlices | None = None,
|
||||||
additional_kwargs: dict[str, Any] | None = None,
|
additional_kwargs: dict[str, Any] | None = None,
|
||||||
|
skip_compiled: bool = False,
|
||||||
):
|
):
|
||||||
return ForwardContext(
|
return ForwardContext(
|
||||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||||
@@ -249,6 +253,7 @@ def create_forward_context(
|
|||||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor,
|
batch_descriptor=batch_descriptor,
|
||||||
ubatch_slices=ubatch_slices,
|
ubatch_slices=ubatch_slices,
|
||||||
|
skip_compiled=skip_compiled,
|
||||||
additional_kwargs=additional_kwargs or {},
|
additional_kwargs=additional_kwargs or {},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -278,6 +283,7 @@ def set_forward_context(
|
|||||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor: BatchDescriptor | None = None,
|
batch_descriptor: BatchDescriptor | None = None,
|
||||||
ubatch_slices: UBatchSlices | None = None,
|
ubatch_slices: UBatchSlices | None = None,
|
||||||
|
skip_compiled: bool = False,
|
||||||
):
|
):
|
||||||
"""A context manager that stores the current forward context,
|
"""A context manager that stores the current forward context,
|
||||||
can be attention metadata, etc.
|
can be attention metadata, etc.
|
||||||
@@ -336,6 +342,7 @@ def set_forward_context(
|
|||||||
batch_descriptor,
|
batch_descriptor,
|
||||||
ubatch_slices,
|
ubatch_slices,
|
||||||
additional_kwargs,
|
additional_kwargs,
|
||||||
|
skip_compiled,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from transformers import (
|
|||||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||||
|
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@@ -561,6 +562,7 @@ class WhisperEncoder(nn.Module):
|
|||||||
return self.forward_layers(hidden_states)
|
return self.forward_layers(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
|
||||||
class WhisperDecoder(nn.Module):
|
class WhisperDecoder(nn.Module):
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -3268,6 +3268,13 @@ class GPUModelRunner(
|
|||||||
# Mark KV scales as calculated after the first forward pass
|
# Mark KV scales as calculated after the first forward pass
|
||||||
self.calculate_kv_scales = False
|
self.calculate_kv_scales = False
|
||||||
|
|
||||||
|
# Encoder-decoder models can only compile the pure decode steps where no
|
||||||
|
# encoder inputs are present. Run eagerly on any step that schedules encoder inputs.
|
||||||
|
num_encoder_reqs = len(scheduler_output.scheduled_encoder_inputs)
|
||||||
|
has_encoder_input = (
|
||||||
|
self.model_config.is_encoder_decoder and num_encoder_reqs > 0
|
||||||
|
)
|
||||||
|
|
||||||
# Run the model.
|
# Run the model.
|
||||||
# Use persistent buffers for CUDA graphs.
|
# Use persistent buffers for CUDA graphs.
|
||||||
with (
|
with (
|
||||||
@@ -3279,6 +3286,7 @@ class GPUModelRunner(
|
|||||||
cudagraph_runtime_mode=cudagraph_mode,
|
cudagraph_runtime_mode=cudagraph_mode,
|
||||||
batch_descriptor=batch_desc,
|
batch_descriptor=batch_desc,
|
||||||
ubatch_slices=ubatch_slices_padded,
|
ubatch_slices=ubatch_slices_padded,
|
||||||
|
skip_compiled=has_encoder_input,
|
||||||
),
|
),
|
||||||
record_function_or_nullcontext("gpu_model_runner: forward"),
|
record_function_or_nullcontext("gpu_model_runner: forward"),
|
||||||
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
|
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
|
||||||
|
|||||||
Reference in New Issue
Block a user