[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
sroy745
2024-11-12 10:53:57 -08:00
committed by GitHub
parent 7c65527918
commit b41fb9d3b1
5 changed files with 117 additions and 80 deletions

View File

@@ -7,7 +7,7 @@ from typing import List, Optional, Tuple
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.attention.selector import (_Backend,
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
@@ -34,6 +34,13 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs
@pytest.fixture(autouse=True)
def clear_cache():
"""Fixture to clear backend cache before each test."""
_cached_get_attn_backend.cache_clear() # Clear the cache
yield # This allows the test to run
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)