[ROCm][CI][Bugfix] Multi-Modal Model Support Fixes and Attention Backend Improvements (#30270)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import SupportsMRoPE, SupportsMultiModal
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
from vllm.multimodal import MultiModalKwargsItems
|
||||
@@ -36,6 +37,7 @@ from vllm.multimodal.inputs import (
|
||||
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -52,6 +54,8 @@ DYNAMIC_ARG_DIMS = {
|
||||
"inputs_embeds": 0,
|
||||
}
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self):
|
||||
@@ -345,8 +349,29 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
|
||||
num_image_patches = kwargs.pop("num_image_patches")
|
||||
kwargs.pop("token_type_ids", None) # used only in `forward`
|
||||
|
||||
if pixel_values is not None:
|
||||
vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
|
||||
# ROCm: Force math SDP backend for vision encoder to avoid accuracy issues
|
||||
# with flash_sdp and mem_efficient_sdp
|
||||
if current_platform.is_rocm():
|
||||
# TODO: [ROCm] Fix accuracy issues with flash backend
|
||||
logger.debug(
|
||||
"ROCm platform detected. Forcing math SDP backend "
|
||||
"for vision encoder. Currently ROCm platform has "
|
||||
"accuracy issues with `flash_sdp` and"
|
||||
"`mem_efficient_sdp` backends. See issue: "
|
||||
"https://github.com/vllm-project/vllm/issues/30167"
|
||||
)
|
||||
with torch.nn.attention.sdpa_kernel(
|
||||
backends=[torch.nn.attention.SDPBackend.MATH]
|
||||
):
|
||||
vision_embeddings = self.model.get_image_features(
|
||||
pixel_values, **kwargs
|
||||
)
|
||||
else:
|
||||
vision_embeddings = self.model.get_image_features(
|
||||
pixel_values, **kwargs
|
||||
)
|
||||
|
||||
if isinstance(vision_embeddings, torch.Tensor):
|
||||
if vision_embeddings.ndim == 2:
|
||||
@@ -364,6 +389,11 @@ class MultiModalMixin(SupportsMultiModal, SupportsMRoPE):
|
||||
]
|
||||
|
||||
return vision_embeddings
|
||||
else:
|
||||
logger.debug(
|
||||
"No pixel values or image embeddings provided for multimodal embedding."
|
||||
)
|
||||
return None
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user