[ROCm][CI] Fix spec decode profile assertion and logprob test determinism (#35043)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -20,6 +20,7 @@ from tests.v1.sample.utils import (
|
||||
from vllm import SamplingParams
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...conftest import HfRunner, VllmRunner
|
||||
|
||||
@@ -31,6 +32,23 @@ SAMPLE = BatchLogprobsComposition.SAMPLE
|
||||
PROMPT = BatchLogprobsComposition.PROMPT
|
||||
SAMPLE_PROMPT = BatchLogprobsComposition.SAMPLE_PROMPT
|
||||
|
||||
# On ROCm, floating-point reductions in attention and GEMM kernels are
|
||||
# non-associative and sensitive to batch geometry. The ref LLM (no spec
|
||||
# decode, default scheduling) and the spec-decode LLM (chunked prefill,
|
||||
# different effective batch sizes) follow different reduction orders,
|
||||
# producing numerically divergent logprobs that get mis-attributed to
|
||||
# spec-decode incorrectness.
|
||||
#
|
||||
# Force both LLM instances (ref and spec-decode) into an identical, deterministic execution
|
||||
# mode so the test isolates spec-decode correctness only:
|
||||
# Extra LLM constructor kwargs: {"max_num_seqs": 1} on ROCm, empty otherwise,
# so splatting it into a constructor call is a no-op on other platforms.
ROCM_DETERMINISM_KWARGS: dict = (
    {"max_num_seqs": 1} if current_platform.is_rocm() else {}
)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
@@ -1035,6 +1053,7 @@ def test_spec_decode_logprobs(
|
||||
logprobs_mode=logprobs_mode,
|
||||
gpu_memory_utilization=0.4,
|
||||
enable_prefix_caching=False,
|
||||
**ROCM_DETERMINISM_KWARGS,
|
||||
)
|
||||
ref_results = ref_llm.generate(
|
||||
[prompt, prompt], [sampling_params, penalty_sampling_params]
|
||||
@@ -1064,6 +1083,7 @@ def test_spec_decode_logprobs(
|
||||
enable_chunked_prefill=True,
|
||||
max_num_batched_tokens=32,
|
||||
enable_prefix_caching=False,
|
||||
**ROCM_DETERMINISM_KWARGS,
|
||||
)
|
||||
spec_results = spec_llm.generate(
|
||||
[prompt, prompt], [sampling_params, penalty_sampling_params]
|
||||
|
||||
Reference in New Issue
Block a user