[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
@@ -6,8 +6,9 @@ from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
BatchEncoding)
|
||||
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_hip
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from ...utils import check_logprobs_close
|
||||
@@ -24,7 +25,7 @@ models = ["google/paligemma-3b-mix-224"]
|
||||
# ROCm Triton FA can run into compilation issues with these models due to,
|
||||
# excessive use of shared memory. Use other backends in the meantime.
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
@@ -70,7 +71,7 @@ def run_test(
|
||||
|
||||
All the image fixtures for the test are from IMAGE_ASSETS.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding MultiModalConfig as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
@@ -151,7 +152,7 @@ def run_test(
|
||||
pytest.param(
|
||||
"float",
|
||||
marks=pytest.mark.skipif(
|
||||
is_hip(),
|
||||
current_platform.is_rocm(),
|
||||
reason=
|
||||
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
|
||||
), "half"
|
||||
|
||||
@@ -12,7 +12,6 @@ from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||
_ImageAssets)
|
||||
@@ -56,7 +55,7 @@ if current_platform.is_cpu():
|
||||
# ROCm Triton FA can run into shared memory issues with these models,
|
||||
# use other backends in the meantime
|
||||
# FIXME (mattwong, gshtrasb, hongxiayan)
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user