[CI] Heavy refactoring of Voxtral multimodal audio model tests (#34294)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

Author: Andreas Karatzas
Date: 2026-02-13 22:04:29 -06:00
Committed by: GitHub
parent 60ca7981bc
commit de42abb366
11 changed files with 350 additions and 70 deletions


@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
 import functools
+import logging
 import math
 from dataclasses import replace
 from functools import partial
@@ -30,11 +31,20 @@ from vllm.v1.attention.backend import (
     subclass_attention_backend_with_overrides,
 )
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+try:
+    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionBackend
+except ImportError:
+    AiterFlashAttentionBackend = None
+
+from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
+from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
 from vllm.v1.attention.selector import get_attn_backend
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 from .utils import make_layers
 
+logger = logging.getLogger(__name__)
+
 CausalRMSNorm = partial(RMSNorm, eps=1e-5)
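Aside: the guarded import above is what lets the supported-backend check later in the diff degrade gracefully on builds where the ROCm AITER backend is unavailable. A minimal standalone sketch of the pattern (with a hypothetical optional module name, not vLLM's):

# Hypothetical sketch of the optional-backend import pattern used above.
try:
    from some_optional_pkg import OptionalBackend  # may be absent on some builds
except ImportError:
    OptionalBackend = None


class AlwaysAvailableBackend:  # stand-in for a backend that always imports
    pass


# Drop unavailable backends so issubclass() never sees None.
SUPPORTED = tuple(b for b in (OptionalBackend, AlwaysAvailableBackend) if b is not None)

assert issubclass(AlwaysAvailableBackend, SUPPORTED)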
@@ -122,6 +132,13 @@ def create_whisper_attention_backend_with_block_pooling(
                 num_kv_heads=kv_cache_spec.num_kv_heads // block_pool_size,
             )
             super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+            # Override model_config-derived values with the actual
+            # encoder values from kv_cache_spec
+            self.num_heads_kv = kv_cache_spec.num_kv_heads
+            self.headdim = kv_cache_spec.head_size
+            # num_heads_q for the encoder is the same as num_kv_heads
+            # (no GQA in whisper encoder)
+            self.num_heads_q = kv_cache_spec.num_kv_heads
 
         def build(
             self,
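Roughly, the builder in the hunk above works with a spec whose KV head count has been folded by the pool size, while its own per-head metadata keeps the real encoder values. A toy illustration of that arithmetic (the dataclass and field names here are invented for the sketch, not vLLM's AttentionSpec API):

from dataclasses import dataclass, replace


@dataclass
class ToySpec:  # invented stand-in for the encoder KV cache spec
    num_kv_heads: int
    head_size: int


block_pool_size = 4
encoder_spec = ToySpec(num_kv_heads=20, head_size=64)

# The pooled spec carries fewer KV heads per (stretched) block ...
pooled_spec = replace(
    encoder_spec, num_kv_heads=encoder_spec.num_kv_heads // block_pool_size
)
assert pooled_spec.num_kv_heads == 5

# ... while the builder's head metadata keeps the real encoder values, and
# num_heads_q == num_heads_kv because the whisper encoder uses no GQA.
num_heads_kv = num_heads_q = encoder_spec.num_kv_heads  # 20
headdim = encoder_spec.head_size  # 64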
@@ -192,13 +209,36 @@ def create_whisper_attention_backend_with_block_pooling(
             output_block_scale,
         )
 
-    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
+    _SUPPORTED_BACKENDS = tuple(
+        b
+        for b in (
+            AiterFlashAttentionBackend,
+            FlashAttentionBackend,
+            RocmAttentionBackend,
+            TritonAttentionBackend,
+        )
+        if b is not None
+    )
+
+    if not issubclass(underlying_attn_backend, _SUPPORTED_BACKENDS):
         raise NotImplementedError(
             f"{underlying_attn_backend} is not yet supported. "
             "Contributions to support more backends are much "
             "appreciated."
         )
+
+    if not issubclass(underlying_attn_backend, FlashAttentionBackend):
+        logger.info(
+            "Using %s for Whisper causal attention with block pooling. "
+            "This backend was recently enabled for this model. "
+            "If you encounter any accuracy or performance issues, "
+            "please open an issue at "
+            "https://github.com/vllm-project/vllm/issues "
+            "with the [ROCm] tag so it can be triaged by the "
+            "appropriate team.",
+            underlying_attn_backend.get_name(),
+        )
+
     attn_backend = subclass_attention_backend_with_overrides(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
@@ -209,14 +249,14 @@
             block_size,
             num_kv_heads,
             head_size,
-            cache_dtype_str: (
-                2,
+            cache_dtype_str: underlying_attn_backend.get_kv_cache_shape(
                 num_blocks,
                 # we stretch each block by `block_pool_size`
                 block_size * block_pool_size,
                 num_kv_heads // block_pool_size,
                 head_size,
-            ),  # TODO: generalize to other backends
+                cache_dtype_str,
+            ),
             "forward_includes_kv_cache_update": True,
         },
     )
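For intuition on the get_kv_cache_shape override above: stretching each block by block_pool_size while dividing the number of KV heads by the same factor leaves the per-block cache capacity unchanged, so only the layout handed to the underlying backend changes. A small self-contained check, assuming for illustration the (2, num_blocks, block_size, num_kv_heads, head_size) layout that the removed tuple literal used; other backends may order dimensions differently:

# Illustrative only: assumes a (2, num_blocks, block_size, num_kv_heads, head_size)
# layout and that num_kv_heads is divisible by block_pool_size.
def pooled_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size, block_pool_size):
    return (
        2,
        num_blocks,
        block_size * block_pool_size,     # each block is stretched ...
        num_kv_heads // block_pool_size,  # ... and holds fewer KV heads
        head_size,
    )


base = (2, 128, 16, 20, 64)
pooled = pooled_kv_cache_shape(128, 16, 20, 64, block_pool_size=4)
assert pooled == (2, 128, 64, 5, 64)

# Elements per block are unchanged: 16 * 20 * 64 == 64 * 5 * 64
assert 16 * 20 * 64 == 64 * 5 * 64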