[CI] improve embed testing (#18747)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -13,9 +13,6 @@ from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
|
||||
|
||||
from .registry import HF_EXAMPLE_MODELS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..conftest import HfRunner
|
||||
|
||||
TokensText = tuple[list[int], str]
|
||||
|
||||
|
||||
@@ -337,22 +334,3 @@ class EmbedModelInfo(NamedTuple):
|
||||
architecture: str = ""
|
||||
dtype: str = "auto"
|
||||
enable_test: bool = True
|
||||
|
||||
|
||||
def run_embedding_correctness_test(
|
||||
hf_model: "HfRunner",
|
||||
inputs: list[str],
|
||||
vllm_outputs: Sequence[list[float]],
|
||||
dimensions: Optional[int] = None,
|
||||
):
|
||||
hf_outputs = hf_model.encode(inputs)
|
||||
if dimensions:
|
||||
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=hf_outputs,
|
||||
embeddings_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user