[Encoder Decoder] Update Mllama to run with both FlashAttention and XFormers (#9982)
Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
@@ -4,6 +4,8 @@ import pytest
|
||||
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
BatchEncoding)
|
||||
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
@@ -14,6 +16,8 @@ from ...utils import check_logprobs_close
|
||||
|
||||
_LIMIT_IMAGE_PER_PROMPT = 3
|
||||
|
||||
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<|image|><|begin_of_text|>The meaning of the image is",
|
||||
@@ -221,6 +225,13 @@ def _run_test(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Fixture to clear backend cache before each test."""
|
||||
_cached_get_attn_backend.cache_clear() # Clear the cache
|
||||
yield # This allows the test to run
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -244,20 +255,26 @@ def _run_test(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
|
||||
model, sizes, dtype, max_tokens,
|
||||
num_logprobs) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
sizes=sizes,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
num_logprobs,
|
||||
attn_backend: _Backend) -> None:
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention works only with bfloat16 data-type
|
||||
dtype = 'bfloat16'
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
sizes=sizes,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@@ -265,9 +282,10 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
model, dtype, max_tokens,
|
||||
num_logprobs) -> None:
|
||||
model, dtype, max_tokens, num_logprobs,
|
||||
attn_backend: _Backend) -> None:
|
||||
|
||||
stop_sign = image_assets[0].pil_image
|
||||
cherry_blossom = image_assets[1].pil_image
|
||||
@@ -291,17 +309,20 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
cherry_blossom.resize((512, 1024)),
|
||||
],
|
||||
])]
|
||||
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention works only with bfloat16 data-type
|
||||
dtype = 'bfloat16'
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@@ -309,8 +330,10 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
|
||||
dtype, max_tokens, num_logprobs) -> None:
|
||||
dtype, max_tokens, num_logprobs,
|
||||
attn_backend: _Backend) -> None:
|
||||
|
||||
stop_sign = image_assets[0].pil_image
|
||||
cherry_blossom = image_assets[1].pil_image
|
||||
@@ -325,14 +348,17 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
|
||||
[stop_sign],
|
||||
[stop_sign, cherry_blossom],
|
||||
])]
|
||||
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
with global_force_attn_backend_context_manager(attn_backend):
|
||||
if attn_backend == _Backend.FLASH_ATTN:
|
||||
# Flash Attention works only with bfloat16 data-type
|
||||
dtype = 'bfloat16'
|
||||
_run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
inputs,
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user