# Source: vllm/tests/models/multimodal/conftest.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM multimodal tests."""
import warnings
import torch
from vllm.platforms import current_platform
def pytest_collection_modifyitems(config, items):
    """Configure ROCm-specific settings based on collected tests."""
    if not current_platform.is_rocm():
        return

    # Test files matching these patterns keep the default SDP backends.
    skip_patterns = ["test_granite_speech.py"]
    for arg in config.args:
        arg_text = str(arg)
        if any(pattern in arg_text for pattern in skip_patterns):
            return

    # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
    # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
    # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    warnings.warn(
        "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
        "to avoid HuggingFace Transformers accuracy issues",
        UserWarning,
        stacklevel=1,
    )
def patch_hf_vision_attn_for_rocm(model):
    """Force SDPA for HF vision encoders on ROCm.

    HF's flash_attention_2 has accuracy issues on ROCm that bypass
    torch.backends.cuda settings. This forces SDPA which then uses
    math_sdp via the pytest_collection_modifyitems settings.
    """
    if not current_platform.is_rocm():
        return

    # Some wrappers nest the HF model under a `.model` attribute;
    # fall back to the object itself otherwise.
    hf_model = getattr(model, "model", model)
    if not hasattr(hf_model, "vision_embedding"):
        return

    vision_tower = hf_model.vision_embedding[0]
    for encoder_layer in vision_tower.encoder.layers:
        if hasattr(encoder_layer, "self_attn"):
            encoder_layer.self_attn.vision_config._attn_implementation = "sdpa"