[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 import functools
+import logging
 import math
 from dataclasses import replace
 from functools import partial
@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import (
     subclass_attention_backend_with_overrides,
 )
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+
+try:
+    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
+except ImportError:
+    AiterFlashAttentionBackend = None
+from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
+from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
 from vllm.v1.attention.selector import get_attn_backend
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 from .utils import make_layers
 
+logger = logging.getLogger(__name__)
+
 CausalRMSNorm = partial(RMSNorm, eps=1e-5)
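The AITER flash-attention backend only ships in ROCm builds, so its import is wrapped in try/except with a None fallback that later code filters out. A minimal self-contained sketch of the same pattern, using a hypothetical optional module rather than the vLLM one:

# Guarded optional import: fall back to a None sentinel when the
# platform-specific module is unavailable (hypothetical module name).
try:
    from optional_rocm_pkg import OptionalBackend
except ImportError:
    OptionalBackend = None

# Downstream code filters the sentinel out before using the name:
available_backends = tuple(b for b in (OptionalBackend,) if b is not None)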
@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling(
                 num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
             )
             super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+            # Override model_config-derived values with the actual
+            # encoder values from kv_cache_spec
+            self.num_heads_kv = kv_cache_spec.num_kv_heads
+            self.headdim = kv_cache_spec.head_size
+            # num_heads_q for the encoder is the same as num_kv_heads
+            # (no GQA in whisper encoder)
+            self.num_heads_q = kv_cache_spec.num_kv_heads
 
         def build(
             self,
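The override above replaces the model_config-derived head geometry with the encoder values carried by kv_cache_spec; because the Whisper encoder uses plain multi-head attention (no GQA), the query-head count equals the KV-head count. A minimal sketch of that derivation, with a hypothetical stand-in for the spec object:

from dataclasses import dataclass

@dataclass
class EncoderKVSpec:
    # Hypothetical stand-in for vLLM's AttentionSpec, for illustration only.
    num_kv_heads: int
    head_size: int

def encoder_head_geometry(spec: EncoderKVSpec) -> tuple[int, int, int]:
    # No GQA in the Whisper encoder: num_heads_q == num_heads_kv.
    return spec.num_kv_heads, spec.num_kv_heads, spec.head_size

# e.g. a Whisper large encoder layer: 20 heads of size 64
assert encoder_head_geometry(EncoderKVSpec(20, 64)) == (20, 20, 64)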
@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling(
                 output_block_scale,
             )
 
-    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
+    _SUPPORTED_BACKENDS = tuple(
+        b
+        for b in (
+            AiterFlashAttentionBackend,
+            FlashAttentionBackend,
+            RocmAttentionBackend,
+            TritonAttentionBackend,
+        )
+        if b is not None
+    )
+
+    if not issubclass(underlying_attn_backend, _SUPPORTED_BACKENDS):
         raise NotImplementedError(
             f"{underlying_attn_backend} is not yet supported. "
             "Contributions to support more backends are much "
             "appreciated."
         )
 
+    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
+        logger.info(
+            "Using %s for Whisper causal attention with block pooling. "
+            "This backend was recently enabled for this model. "
+            "If you encounter any accuracy or performance issues, "
+            "please open an issue at "
+            "https://github.com/vllm-project/vllm/issues "
+            "with the [ROCm] tag so it can be triaged by the "
+            "appropriate team.",
+            underlying_attn_backend.get_name(),
+        )
+
     attn_backend = subclass_attention_backend_with_overrides(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
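issubclass() accepts a tuple of classes but raises TypeError on None entries, which is exactly what a failed AITER import leaves behind; hence the generator filter before the check. A runnable sketch of the gating with stand-in classes:

# Stand-in backend classes; AiterFlashAttentionBackend simulates a
# ROCm-only import that failed on this platform.
class FlashAttentionBackend: ...
class TritonAttentionBackend: ...
AiterFlashAttentionBackend = None

_SUPPORTED_BACKENDS = tuple(
    b
    for b in (AiterFlashAttentionBackend, FlashAttentionBackend,
              TritonAttentionBackend)
    if b is not None  # drop backends that failed to import
)

class CustomBackend(TritonAttentionBackend):
    pass

# Passes: CustomBackend derives from one of the supported backends.
assert issubclass(CustomBackend, _SUPPORTED_BACKENDS)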
@@ -209,14 +249,14 @@ def create_whisper_attention_backend_with_block_pooling(
             block_size,
             num_kv_heads,
             head_size,
-            cache_dtype_str: (
-                2,
+            cache_dtype_str: underlying_attn_backend.get_kv_cache_shape(
                 num_blocks,
                 # we stretch each block by `block_pool_size`
                 block_size * block_pool_size,
                 num_kv_heads // block_pool_size,
                 head_size,
-            ),
+                cache_dtype_str,
+            ), # TODO: generalize to other backends
             "forward_includes_kv_cache_update": True,
         },
     )
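The overridden get_kv_cache_shape stretches each block by block_pool_size while dividing the KV-head count by the same factor, so the cache footprint is unchanged and only the blocking differs. A quick check of that invariant, using illustrative sizes and the FlashAttention-style (2, num_blocks, block_size, num_kv_heads, head_size) layout:

from math import prod

num_blocks, block_size, num_kv_heads, head_size = 128, 16, 20, 64
block_pool_size = 4

base = (2, num_blocks, block_size, num_kv_heads, head_size)
pooled = (2, num_blocks, block_size * block_pool_size,
          num_kv_heads // block_pool_size, head_size)

# Same total element count: the pooling only regroups the cache.
assert prod(base) == prod(pooled)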