[v1] Add Whisper model support (encoder-decoder) (#21088)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
@@ -317,8 +317,8 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device) -> None:
        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.scheduler_config = vllm_config.scheduler_config

        # For reorder
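The same refactor repeats across every backend hunk below: builders that previously stored kv_cache_spec, vllm_config, and device themselves now delegate that bookkeeping to AttentionMetadataBuilder.__init__ (added near the end of this diff). A minimal before/after sketch of the pattern; SomeBackendBuilder and SomeMetadata are placeholders, not classes from this diff:

    # Before: each builder duplicated the common assignments.
    class SomeBackendBuilder(AttentionMetadataBuilder[SomeMetadata]):
        def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                     vllm_config: VllmConfig, device: torch.device):
            self.kv_cache_spec = kv_cache_spec
            self.vllm_config = vllm_config

    # After: shared state is stored once by the base class; the subclass
    # keeps only its backend-specific setup.
    class SomeBackendBuilder(AttentionMetadataBuilder[SomeMetadata]):
        def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                     vllm_config: VllmConfig, device: torch.device):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            self.scheduler_config = vllm_config.scheduler_config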
@@ -177,12 +177,11 @@ class FlashAttentionMetadataBuilder(

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
@@ -163,11 +163,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        self.device = device
        self.vllm_config = vllm_config
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.cache_config = vllm_config.cache_config
        self.model_config = vllm_config.model_config
        self.kv_cache_spec = kv_cache_spec
        self._workspace_buffer = None
        self._prefill_wrapper = None  # Wrapper for prefill/append
        self._decode_wrapper = None  # Wrapper for decode (general shape)
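FlashInfer's workspace buffer and wrappers are only nulled out here; they are created lazily on first use. A sketch of that lazy-initialization pattern, where the helper name and the 128 MiB workspace size are illustrative assumptions rather than values taken from this diff:

    def _get_workspace_buffer(self) -> torch.Tensor:
        # Allocate the shared workspace once, on first request, then reuse it.
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.zeros(
                128 * 1024 * 1024, dtype=torch.uint8, device=self.device)
        return self._workspace_buffer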
@@ -516,10 +516,11 @@ class FlexAttentionMetadataBuilder(

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.device = device

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
@@ -39,8 +39,8 @@ class LinearAttentionMetadataBuilder(

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec

    def build(self,
              common_prefix_len: int,
@@ -22,12 +22,9 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        self.device = device
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        assert isinstance(kv_cache_spec, MambaSpec)
        self.compilation_config = vllm_config.compilation_config
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
@@ -52,4 +49,4 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):

        m.max_query_len = 1  # decode-only

        return self.build(0, m)
        return self.build(0, m)
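For CUDA graph capture, the Mamba base builder pins the batch to a decode-only shape and reuses the ordinary build() path. Matching the build(self, common_prefix_len: int, ...) signatures shown above, the first argument of self.build(0, m) is common_prefix_len, so capture-time metadata is built with no shared prefix; a commented sketch:

    # Sketch: capture a decode-only batch with no common prefix.
    m.max_query_len = 1      # one query token per sequence (pure decode)
    return self.build(0, m)  # common_prefix_len=0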
@@ -236,11 +236,11 @@ class AiterFlashAttentionMetadataBuilder(

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.device = device

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
@@ -248,7 +248,6 @@ class AiterFlashAttentionMetadataBuilder(
            self.parallel_config)
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None
@@ -45,8 +45,8 @@ class ShortConvAttentionMetadataBuilder(

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec

    def build(self,
              common_prefix_len: int,
@@ -165,7 +165,8 @@ class TreeAttentionMetadataBuilder(
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
@@ -66,9 +66,9 @@ class TritonAttentionMetadataBuilder(

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        self.device = device
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec

        model_config = vllm_config.model_config
        self.num_heads_q = model_config.get_num_attention_heads(
@@ -72,6 +72,9 @@ class CommonAttentionMetadata:
    logits_indices_padded: Optional[torch.Tensor] = None
    num_logits_indices: Optional[int] = None

    # Needed by CrossAttentionBuilder
    encoder_seq_lens: Optional[np.ndarray] = None


@dataclass
class UbatchSlice:
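encoder_seq_lens is the one field added here for the new encoder-decoder path: cross-attention layers attend over encoder outputs, so their metadata must be sized by encoder lengths rather than decoder token counts. The CrossAttentionBuilder named in the comment is not part of this excerpt; a hypothetical sketch of a consumer:

    import numpy as np

    def max_encoder_len(common: CommonAttentionMetadata) -> int:
        # Cross-attention KV comes from the encoder, so its extent is
        # bounded by the longest encoder sequence in the batch.
        assert common.encoder_seq_lens is not None
        return int(common.encoder_seq_lens.max())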
@@ -193,6 +196,9 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    @abstractmethod
    def build(self,
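With this base-class __init__ in place, every builder exposes the same four attributes regardless of backend. Illustrative usage only, where ConcreteBuilder stands in for any of the subclasses above:

    builder = ConcreteBuilder(kv_cache_spec, layer_names, vllm_config, device)
    assert builder.kv_cache_spec is kv_cache_spec
    assert builder.device == device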
@@ -206,8 +206,9 @@ class XFormersAttentionMetadataBuilder(
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        assert XFORMERS_AVAILABLE
        self.kv_cache_spec = kv_cache_spec
        self.block_size = kv_cache_spec.block_size
        self._num_decodes = 0
        self._num_decode_tokens = 0