[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)
Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
|
||||
import pytest
|
||||
from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from vllm.attention.selector import (_Backend,
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@@ -34,6 +34,13 @@ def vllm_to_hf_output(
|
||||
return output_ids, hf_output_str, out_logprobs
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Fixture to clear backend cache before each test."""
|
||||
_cached_get_attn_backend.cache_clear() # Clear the cache
|
||||
yield # This allows the test to run
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
|
||||
Reference in New Issue
Block a user