[ROCm][Bugfix] Patch for the Multi-Modal Processor Test group (#29702)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -30,6 +30,7 @@ from vllm.model_executor.models.interfaces import (
 from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
 from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
+from vllm.platforms import current_platform
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.torch_utils import set_default_torch_dtype
@@ -176,6 +177,12 @@ def test_model_tensor_schema(model_id: str):
         exist_overrides=model_info.hf_overrides,
     )
 
+    # ROCm: Detect if model uses AWQ quantization and set appropriate dtype
+    if "awq" in model_id.lower() and current_platform.is_rocm():
+        dtype = "float16"
+    else:
+        dtype = model_info.dtype
+
     model_config = ModelConfig(
         model_id,
         tokenizer=model_info.tokenizer or model_id,
@@ -187,7 +194,7 @@ def test_model_tensor_schema(model_id: str):
         enable_prompt_embeds=model_info.require_embed_inputs,
         enable_mm_embeds=model_info.require_embed_inputs,
         enforce_eager=model_info.enforce_eager,
-        dtype=model_info.dtype,
+        dtype=dtype,
     )
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
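
For readers outside the diff context, here is a minimal, self-contained sketch of the dtype selection this patch introduces. The helper name, its parameters, and the example model ids are illustrative assumptions, not part of the patch; the actual test reads the platform via current_platform.is_rocm() from vllm.platforms.

    # Minimal sketch (assumed names): the patched branch in isolation.
    def select_dtype(model_id: str, default_dtype: str, is_rocm: bool) -> str:
        # AWQ checkpoints conventionally carry "awq" in their repo id; on ROCm
        # the AWQ path is exercised in float16, so force that dtype there.
        if "awq" in model_id.lower() and is_rocm:
            return "float16"
        return default_dtype

    # Illustrative usage (model ids are made up):
    assert select_dtype("org/model-AWQ", "bfloat16", is_rocm=True) == "float16"
    assert select_dtype("org/model", "bfloat16", is_rocm=True) == "bfloat16"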