Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -72,7 +72,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
Notes:
|
||||
- Use seeded stochastic sampling with a fixed seed to test determinism.
|
||||
- Outputs are intentionally longer and sampled at higher temperature/top_p
|
||||
to produce a more random-sounding phrase, yet remain deterministic by
|
||||
to produce a more random-sounding phrase, yet remain deterministic by
|
||||
seed.
|
||||
- Keep max_tokens and max_model_len bounded for speed and memory use.
|
||||
"""
|
||||
@@ -103,7 +103,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
seed=20240919,
|
||||
)
|
||||
|
||||
needle_prompt = ("There once was a ")
|
||||
needle_prompt = "There once was a "
|
||||
|
||||
llm_bs1 = None
|
||||
llm_bsN = None
|
||||
@@ -158,13 +158,16 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
|
||||
passes = num_trials - mismatches
|
||||
# Dump how many passed vs failed
|
||||
print(f"[determinism] total={num_trials}, passed={passes}, "
|
||||
f"failed={mismatches}, batch_size={batch_size}")
|
||||
print(
|
||||
f"[determinism] total={num_trials}, passed={passes}, "
|
||||
f"failed={mismatches}, batch_size={batch_size}"
|
||||
)
|
||||
|
||||
if mismatches > 0:
|
||||
pytest.fail(
|
||||
f"Nondeterministic outputs detected: {mismatches} failed out "
|
||||
f"of {num_trials} trials (batch_size={batch_size}).")
|
||||
f"of {num_trials} trials (batch_size={batch_size})."
|
||||
)
|
||||
|
||||
finally:
|
||||
# Ensure engines are shutdown to free GPU/VRAM across test sessions
|
||||
@@ -197,8 +200,7 @@ def _extract_step_logprobs(request_output):
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
|
||||
|
||||
#model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
|
||||
# model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
@@ -230,8 +232,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
|
||||
assert len(outs) == 1
|
||||
step_logprobs = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip(
|
||||
"Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test."
|
||||
)
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
|
||||
# BS=2: run prompts in a batch and collect logprobs per step for each
|
||||
@@ -242,24 +246,29 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
|
||||
for o in outs_batched:
|
||||
step_logprobs = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
pytest.skip(
|
||||
"Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test."
|
||||
)
|
||||
bs2_logprobs_per_prompt.append(step_logprobs)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
|
||||
for i, (logprobs_bs1, logprobs_bs2) in enumerate(
|
||||
zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
|
||||
zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)
|
||||
):
|
||||
assert len(logprobs_bs1) == len(logprobs_bs2), (
|
||||
f"Different number of generation steps for prompt index {i}: "
|
||||
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
|
||||
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)"
|
||||
)
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
|
||||
assert a.shape == b.shape, (
|
||||
f"Logits shape mismatch at prompt {i}, step {t}: "
|
||||
f"{a.shape} vs {b.shape}")
|
||||
f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
|
||||
)
|
||||
# Bitwise exact equality.
|
||||
assert torch.equal(
|
||||
a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} "
|
||||
f"(dtype={a.dtype}, shape={a.shape}).")
|
||||
assert torch.equal(a, b), (
|
||||
f"Bitwise logprobs mismatch at prompt {i}, step {t} "
|
||||
f"(dtype={a.dtype}, shape={a.shape})."
|
||||
)
|
||||
|
||||
|
||||
def LLM_with_max_seqs(
|
||||
|
||||
Reference in New Issue
Block a user