[v1] Add Whisper model support (encoder-decoder) (#21088)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Author: Russell Bryant
Date: 2025-09-10 16:53:35 -04:00
Committed by: GitHub
Parent: 4db4426404
Commit: 37e8182bfe
31 changed files with 429 additions and 92 deletions

@@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device) -> None:
-        self.kv_cache_spec = kv_cache_spec
-        self.vllm_config = vllm_config
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.scheduler_config = vllm_config.scheduler_config
         # For reorder

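The backend hunks in this commit all apply one refactor: each attention metadata builder stops assigning the shared fields (kv_cache_spec, layer_names, vllm_config, device) itself and delegates them to the base AttentionMetadataBuilder.__init__, keeping only backend-specific setup. A minimal, self-contained sketch of that pattern follows; the class names are stand-ins for illustration, not the classes touched in this diff.

    import torch


    class BaseBuilderSketch:
        """Stand-in for the base metadata builder that now owns the shared fields."""

        def __init__(self, kv_cache_spec, layer_names: list[str],
                     vllm_config, device: torch.device):
            # Shared fields, previously re-assigned by every backend builder.
            self.kv_cache_spec = kv_cache_spec
            self.layer_names = layer_names
            self.vllm_config = vllm_config
            self.device = device


    class ExampleBackendBuilderSketch(BaseBuilderSketch):
        """Stand-in for a per-backend builder after the refactor."""

        def __init__(self, kv_cache_spec, layer_names: list[str],
                     vllm_config, device: torch.device):
            # Delegate the shared assignments to the base class...
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            # ...and keep only backend-specific setup in the subclass.
            self.block_size = kv_cache_spec.block_size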
@@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder(
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        self.vllm_config = vllm_config
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
         self.compilation_config = vllm_config.compilation_config
-        self.device = device
         self.num_heads_q = self.model_config.get_num_attention_heads(
             self.parallel_config)

@@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        self.device = device
-        self.vllm_config = vllm_config
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.cache_config = vllm_config.cache_config
         self.model_config = vllm_config.model_config
-        self.kv_cache_spec = kv_cache_spec
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)

@@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder(
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
-        self.device = device
         self.num_heads_q = self.model_config.get_num_attention_heads(
             self.parallel_config)

@@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder(
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec

     def build(self,
               common_prefix_len: int,

@@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec
-        self.device = device
-        self.vllm_config = vllm_config
-        self.layer_names = layer_names
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+        assert isinstance(kv_cache_spec, MambaSpec)
         self.compilation_config = vllm_config.compilation_config
         self.decode_cudagraph_max_bs = min(
             self.vllm_config.scheduler_config.max_num_seqs,
@@ -52,4 +49,4 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         m.max_query_len = 1  # decode-only
-        return self.build(0, m)
+        return self.build(0, m)

@@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder(
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        self.vllm_config = vllm_config
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.model_config = vllm_config.model_config
         self.parallel_config = vllm_config.parallel_config
         self.cache_config = vllm_config.cache_config
-        self.device = device
         self.num_heads_q = self.model_config.get_num_attention_heads(
             self.parallel_config)
@@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder(
             self.parallel_config)
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
-        self.kv_cache_spec = kv_cache_spec
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None

@@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder(
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec

     def build(self,
               common_prefix_len: int,

@@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder(
         vllm_config: VllmConfig,
         device: torch.device,
     ):
-        self.kv_cache_spec = kv_cache_spec
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.block_size = kv_cache_spec.block_size

         spec_config = vllm_config.speculative_config

@@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder(
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        self.device = device
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.block_size = kv_cache_spec.block_size
-        self.kv_cache_spec = kv_cache_spec
         model_config = vllm_config.model_config
         self.num_heads_q = model_config.get_num_attention_heads(

@@ -72,6 +72,9 @@ class CommonAttentionMetadata:
     logits_indices_padded: Optional[torch.Tensor] = None
     num_logits_indices: Optional[int] = None

+    # Needed by CrossAttentionBuilder
+    encoder_seq_lens: Optional[np.ndarray] = None


 @dataclass
 class UbatchSlice:
@@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
         self.kv_cache_spec = kv_cache_spec
+        self.layer_names = layer_names
+        self.vllm_config = vllm_config
+        self.device = device

     @abstractmethod
     def build(self,

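The CommonAttentionMetadata hunk above adds an encoder_seq_lens field for the new CrossAttentionBuilder. For an encoder-decoder model such as Whisper, cross-attention layers attend over the encoder output, so the metadata has to carry each request's encoder length in addition to its decoder query length. A hedged sketch of that distinction, with made-up names rather than the actual vLLM classes:

    from dataclasses import dataclass
    from typing import Optional

    import numpy as np


    @dataclass
    class MetadataSketch:
        """Stand-in for the metadata handed to the attention builders."""
        # Decoder-side query lengths per request.
        query_lens: np.ndarray
        # Encoder lengths; only populated for encoder-decoder models.
        encoder_seq_lens: Optional[np.ndarray] = None


    def cross_attention_kv_lens(meta: MetadataSketch) -> np.ndarray:
        # Cross-attention keys/values come from the encoder output, so the
        # KV lengths are the encoder lengths, not the decoder query lengths.
        assert meta.encoder_seq_lens is not None, "encoder-decoder input expected"
        return meta.encoder_seq_lens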
@@ -206,8 +206,9 @@ class XFormersAttentionMetadataBuilder(
         vllm_config: VllmConfig,
         device: torch.device,
     ):
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         assert XFORMERS_AVAILABLE
-        self.kv_cache_spec = kv_cache_spec
         self.block_size = kv_cache_spec.block_size
         self._num_decodes = 0
         self._num_decode_tokens = 0