[ROCm][Bugfix] Patch for the Multi-Modal Processor Test group (#29702)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2025-11-28 19:31:44 -06:00
committed by GitHub
parent c625d7b1c6
commit ea3370b428
4 changed files with 105 additions and 29 deletions

View File

@@ -30,6 +30,7 @@ from vllm.model_executor.models.interfaces import (
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of
from vllm.utils.torch_utils import set_default_torch_dtype
@@ -176,6 +177,12 @@ def test_model_tensor_schema(model_id: str):
exist_overrides=model_info.hf_overrides,
)
# ROCm: treat a model as AWQ-quantized when its ID contains "awq" and use float16 for it; otherwise keep the dtype from model_info
if "awq" in model_id.lower() and current_platform.is_rocm():
dtype = "float16"
else:
dtype = model_info.dtype
model_config = ModelConfig(
model_id,
tokenizer=model_info.tokenizer or model_id,
@@ -187,7 +194,7 @@ def test_model_tensor_schema(model_id: str):
enable_prompt_embeds=model_info.require_embed_inputs,
enable_mm_embeds=model_info.require_embed_inputs,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype,
dtype=dtype,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)