[ROCm][CI] Fix spec decode profile assertion and logprob test determinism (#35043)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-23 07:05:54 -06:00
committed by GitHub
parent aa08a30fc9
commit 5f68464f92

View File

@@ -20,6 +20,7 @@ from tests.v1.sample.utils import (
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
from ...conftest import HfRunner, VllmRunner
@@ -31,6 +32,23 @@ SAMPLE = BatchLogprobsComposition.SAMPLE
PROMPT = BatchLogprobsComposition.PROMPT
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
# On ROCm, floating-point reductions in attention and GEMM kernels are
# non-associative and sensitive to batch geometry. The ref LLM (no spec
# decode, default scheduling) and the spec-decode LLM (chunked prefill,
# different effective batch sizes) follow different reduction orders,
# producing numerically divergent logprobs that get mis-attributed to
# spec-decode incorrectness.
#
# Force both LLM instances (reference and spec-decode) into the same
# deterministic execution mode so the test isolates spec-decode correctness:
# Extra keyword arguments for the LLM instances under test: cap
# max_num_seqs at 1 when running on ROCm; empty on every other platform.
ROCM_DETERMINISM_KWARGS: dict = (
    {"max_num_seqs": 1} if current_platform.is_rocm() else {}
)
@pytest.fixture(
scope="module",
@@ -1035,6 +1053,7 @@ def test_spec_decode_logprobs(
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
enable_prefix_caching=False,
**ROCM_DETERMINISM_KWARGS,
)
ref_results = ref_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]
@@ -1064,6 +1083,7 @@ def test_spec_decode_logprobs(
enable_chunked_prefill=True,
max_num_batched_tokens=32,
enable_prefix_caching=False,
**ROCM_DETERMINISM_KWARGS,
)
spec_results = spec_llm.generate(
[prompt, prompt], [sampling_params, penalty_sampling_params]