[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash Attention to enable embedding models. (#14664)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: TJian
Date: 2025-03-13 00:31:19 +08:00
Committed by: GitHub
Parent: d9f83d6206
Commit: 916836bbfb
7 changed files with 118 additions and 50 deletions
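For context, a minimal sketch of what this change enables: serving an encoder-only embedding model through vLLM on a ROCm GPU. The sketch is not part of this commit; the model choice is illustrative, and the `task="embed"` / `llm.embed(...)` pooling API and the `output.outputs.embedding` field are assumptions based on vLLM's documented embedding workflow around this release.

    # Hedged sketch, assuming vLLM's pooling API of this era.
    # With this commit, the ROCm Flash Attention backend can run the
    # bidirectional (encoder-only) attention that embedding models need.
    from vllm import LLM

    # BAAI/bge-base-en-v1.5 is a hypothetical example of an encoder-only
    # (BERT-style) embedding model; any supported embedding model works.
    llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

    outputs = llm.embed(["vLLM now supports embedding models on ROCm."])
    for output in outputs:
        # Each result carries the pooled embedding vector.
        print(len(output.outputs.embedding))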


@@ -4,10 +4,27 @@ import pytest
 import torch.nn.functional as F
 from transformers import AutoModelForVision2Seq
 
+from vllm.platforms import current_platform
+
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
 from ....utils import large_gpu_test
 from ..utils import check_embeddings_close
 
+# Llava Next embedding implementation is only supported by CUDA.
+# If run on ROCm, hf_model.model.resize_token_embeddings will
+# cause the following error:
+# RuntimeError: Calling torch.linalg.cholesky on a CUDA tensor
+# requires compiling PyTorch with MAGMA. Please use PyTorch
+# built with MAGMA support.
+# If run on CPU, hf_model.model.resize_token_embeddings will
+# cause the following error:
+# RuntimeError: Calling torch.linalg.cholesky on a CPU tensor
+# requires compiling PyTorch with LAPACK. Please use PyTorch
+# built with LAPACK support.
+pytestmark = pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="Llava Next model uses op that is only supported in CUDA")
+
 llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501
 HF_TEXT_PROMPTS = [