[ROCm][CI] Fix flaky embedding chat test by using tolerance-based comparison (#35050)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-22 03:03:44 -06:00
committed by GitHub
parent 40f88d8318
commit a8a47c17b6

View File

@@ -58,13 +58,19 @@ if current_platform.is_rocm():
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
# On ROCm, floating-point reductions in attention and GEMM kernels are
# non-associative and sensitive to batch geometry. Force LLM instances
# into an identical, deterministic execution mode:
ROCM_DETERMINISM_ARGS: list[str] = (
["--max-num-seqs", "1"] if current_platform.is_rocm() else []
)
 @pytest.fixture(scope="module")
 def server():
     args = [
         "--runner",
         "pooling",
+        # use half precision for speed and memory savings in CI environment
         "--dtype",
         DTYPE,
         "--enforce-eager",
@@ -72,12 +78,9 @@ def server():
         "512",
         "--chat-template",
         DUMMY_CHAT_TEMPLATE,
+        *ROCM_DETERMINISM_ARGS,
     ]
# ROCm: Use Flex Attention to support encoder-only self-attention.
if current_platform.is_rocm():
args.extend(["--attention-backend", "FLEX_ATTENTION"])
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
@@ -343,8 +346,15 @@ async def test_chat_request(
     assert chat_embeddings.id is not None
     assert completion_embeddings.id is not None
     assert chat_embeddings.created <= completion_embeddings.created
-    assert chat_embeddings.model_dump(exclude={"id", "created"}) == (
-        completion_embeddings.model_dump(exclude={"id", "created"})
-    )
+    # Use tolerance-based comparison for embeddings
+    check_embeddings_close(
+        embeddings_0_lst=[d.embedding for d in chat_embeddings.data],
+        embeddings_1_lst=[d.embedding for d in completion_embeddings.data],
+        name_0="chat",
+        name_1="completion",
+    )
+    assert chat_embeddings.model_dump(exclude={"id", "created", "data"}) == (
+        completion_embeddings.model_dump(exclude={"id", "created", "data"})
+    )

     # test add_generation_prompt