[Core] Whisper support torch.compile (#30385)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from transformers import (
|
||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@@ -561,6 +562,7 @@ class WhisperEncoder(nn.Module):
|
||||
return self.forward_layers(hidden_states)
|
||||
|
||||
|
||||
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
|
||||
class WhisperDecoder(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user