[ROCm][CI] Fix flaky embedding chat test by using tolerance-based comparison (#35050)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -58,13 +58,19 @@ if current_platform.is_rocm():
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_math_sdp(True)
|
||||
|
||||
# On ROCm, floating-point reductions in attention and GEMM kernels are
|
||||
# non-associative and sensitive to batch geometry. Force LLM instances
|
||||
# into an identical, deterministic execution mode:
|
||||
ROCM_DETERMINISM_ARGS: list[str] = (
|
||||
["--max-num-seqs", "1"] if current_platform.is_rocm() else []
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
DTYPE,
|
||||
"--enforce-eager",
|
||||
@@ -72,12 +78,9 @@ def server():
|
||||
"512",
|
||||
"--chat-template",
|
||||
DUMMY_CHAT_TEMPLATE,
|
||||
*ROCM_DETERMINISM_ARGS,
|
||||
]
|
||||
|
||||
# ROCm: Use Flex Attention to support encoder-only self-attention.
|
||||
if current_platform.is_rocm():
|
||||
args.extend(["--attention-backend", "FLEX_ATTENTION"])
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
@@ -343,8 +346,15 @@ async def test_chat_request(
|
||||
assert chat_embeddings.id is not None
|
||||
assert completion_embeddings.id is not None
|
||||
assert chat_embeddings.created <= completion_embeddings.created
|
||||
assert chat_embeddings.model_dump(exclude={"id", "created"}) == (
|
||||
completion_embeddings.model_dump(exclude={"id", "created"})
|
||||
# Use tolerance-based comparison for embeddings
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[d.embedding for d in chat_embeddings.data],
|
||||
embeddings_1_lst=[d.embedding for d in completion_embeddings.data],
|
||||
name_0="chat",
|
||||
name_1="completion",
|
||||
)
|
||||
assert chat_embeddings.model_dump(exclude={"id", "created", "data"}) == (
|
||||
completion_embeddings.model_dump(exclude={"id", "created", "data"})
|
||||
)
|
||||
|
||||
# test add_generation_prompt
|
||||
|
||||
Reference in New Issue
Block a user