Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
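This commit mechanically reformats the test suite: yapf's align-to-open-parenthesis wrapping and isort's import grouping are replaced by ruff's Black-style formatting and import sorting; there are no behavior changes. A minimal sketch of the style change, taken from the hunks below (the ruff configuration itself is not part of this diff):

# Before (yapf): continuation arguments aligned under the opening parenthesis
output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
                                   log_stats=False)

# After (ruff format): a call stays on one line when it fits the line length,
# otherwise it gets a 4-space hanging indent and a trailing comma
output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)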
@@ -7,19 +7,20 @@ from typing import Optional
 
 import pytest
 
-from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
-                                   NUM_SAMPLE_LOGPROBS_UNDER_TEST,
-                                   STOP_STRINGS,
-                                   DummyOutputProcessorTestVectors,
-                                   MockEngineCore)
+from tests.v1.engine.utils import (
+    NUM_PROMPT_LOGPROBS_UNDER_TEST,
+    NUM_SAMPLE_LOGPROBS_UNDER_TEST,
+    STOP_STRINGS,
+    DummyOutputProcessorTestVectors,
+    MockEngineCore,
+)
 from vllm import PoolingParams
 from vllm.logprobs import PromptLogprobs, SampleLogprobs
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.output_processor import (OutputProcessor,
-                                             RequestOutputCollector)
+from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
 from vllm.v1.metrics.stats import IterationStats
 
 
@@ -40,33 +41,34 @@ def _ref_convert_id_to_token(
 
 
 @pytest.mark.parametrize(
-    "request_output_kind",
-    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-def test_incremental_detokenization(request_output_kind: RequestOutputKind,
-                                    dummy_test_vectors):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
-                                       log_stats=False)
-    engine_core = MockEngineCore(
-        tokens_list=dummy_test_vectors.generation_tokens)
+    "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
+)
+def test_incremental_detokenization(
+    request_output_kind: RequestOutputKind, dummy_test_vectors
+):
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
+    engine_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)
 
     # Make N requests.
     requests = [
-        EngineCoreRequest(request_id=f"request-{idx}",
-                          prompt_token_ids=prompt_tokens,
-                          mm_features=None,
-                          eos_token_id=None,
-                          arrival_time=0,
-                          lora_request=None,
-                          cache_salt=None,
-                          data_parallel_rank=None,
-                          sampling_params=SamplingParams(
-                              skip_special_tokens=False,
-                              spaces_between_special_tokens=False,
-                              output_kind=request_output_kind,
-                              stop=[],
-                              include_stop_str_in_output=False,
-                          ),
-                          pooling_params=None)
+        EngineCoreRequest(
+            request_id=f"request-{idx}",
+            prompt_token_ids=prompt_tokens,
+            mm_features=None,
+            eos_token_id=None,
+            arrival_time=0,
+            lora_request=None,
+            cache_salt=None,
+            data_parallel_rank=None,
+            sampling_params=SamplingParams(
+                skip_special_tokens=False,
+                spaces_between_special_tokens=False,
+                output_kind=request_output_kind,
+                stop=[],
+                include_stop_str_in_output=False,
+            ),
+            pooling_params=None,
+        )
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
@@ -102,8 +104,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
 
     # Confirmed tracked values matches what we expected.
     for idx, (ref_gen_str, ref_gen_toks) in enumerate(
-            zip(dummy_test_vectors.generation_strings,
-                dummy_test_vectors.generation_tokens)):
+        zip(dummy_test_vectors.generation_strings, dummy_test_vectors.generation_tokens)
+    ):
         gen_str = gen_strings[f"request-{idx}"]
         gen_toks = gen_tokens[f"request-{idx}"]
 
@@ -134,9 +136,11 @@ def _validate_logprobs(
         ref_prompt_logprobs = dtv.prompt_logprobs[req_idx]
         if num_sample_logprobs is not None:
             # Validate sample logprobs
-            assert logprobs is not None, (f"Request {req_id} requires sample"
-                                          " logprobs but sample logprobs are"
-                                          " None.")
+            assert logprobs is not None, (
+                f"Request {req_id} requires sample"
+                " logprobs but sample logprobs are"
+                " None."
+            )
             # Require num sampled tokens to match num
             # sampled logprobs - especially important
             # to check since the detokenizer can cause
@@ -147,44 +151,51 @@ def _validate_logprobs(
             assert num_new_tokens == len_sample_logprobs, (
                 f"Request {req_id} has {num_new_tokens}"
                 " completion tokens but has"
-                f" {len_sample_logprobs} sample logprobs.")
+                f" {len_sample_logprobs} sample logprobs."
+            )
             ref_cumulative_logprob = 0.0
-            for idx, (sampled_token,
-                      pos_logprob_dict) in enumerate(zip(new_tokens,
-                                                         logprobs)):
+            for idx, (sampled_token, pos_logprob_dict) in enumerate(
+                zip(new_tokens, logprobs)
+            ):
                 # Break out the reference log probability value &
                 # logprob token id tensors associated with this
                 # position in the completion. Also break out the
                 # sampled token ranks
-                (ref_pos_logprob_toks, ref_pos_logprob_vals,
-                 ref_sampled_token_rank) = ref_logprobs[idx]
+                (ref_pos_logprob_toks, ref_pos_logprob_vals, ref_sampled_token_rank) = (
+                    ref_logprobs[idx]
+                )
                 # For each position in the completion sequence,
                 # ensure the actual sampled token is among the
                 # logprobs
                 assert sampled_token in pos_logprob_dict, (
                     f"Sampled token {sampled_token} not"
-                    f" present in logprob at index {idx}")
+                    f" present in logprob at index {idx}"
+                )
 
                 # Validate number of sample logprobs
                 num_lp_toks = len(pos_logprob_dict)
-                assert (num_lp_toks == num_sample_logprobs
-                        or num_lp_toks == num_sample_logprobs +
-                        1), ("Valid numbers of sample logprobs are"
-                             f" {num_sample_logprobs} or"
-                             f" {num_sample_logprobs+1} but"
-                             f" {num_lp_toks} logprobs found at"
-                             f" position {idx}. Logprobs dict:"
-                             f" {pos_logprob_dict}")
+                assert (
+                    num_lp_toks == num_sample_logprobs
+                    or num_lp_toks == num_sample_logprobs + 1
+                ), (
+                    "Valid numbers of sample logprobs are"
+                    f" {num_sample_logprobs} or"
+                    f" {num_sample_logprobs + 1} but"
+                    f" {num_lp_toks} logprobs found at"
+                    f" position {idx}. Logprobs dict:"
+                    f" {pos_logprob_dict}"
+                )
 
                 # Validate sampled token logprob rank
                 smp_lp = pos_logprob_dict[sampled_token]
                 smp_lp_rank = smp_lp.rank
-                assert (ref_sampled_token_rank == smp_lp_rank), (
+                assert ref_sampled_token_rank == smp_lp_rank, (
                     "Sampled token logprob rank"
                     f" {smp_lp_rank} does not match"
                     " correct value"
                     f" {ref_sampled_token_rank}"
-                    f" in Logprob {smp_lp}")
+                    f" in Logprob {smp_lp}"
+                )
 
                 # Validate that the logprob processor yields
                 # the correct log probabilities and valid
@@ -198,7 +209,8 @@ def _validate_logprobs(
                     ref_tok_id = ref_pos_logprob_toks[jdx]
                     assert ref_tok_id in pos_logprob_dict, (
                         f"Expected token {ref_tok_id} to be"
-                        f" in logprob dict but it is not.")
+                        f" in logprob dict but it is not."
+                    )
 
                     # Extract actually-generated logprob
                     # info
@@ -208,40 +220,43 @@ def _validate_logprobs(
                     # A "top" (rank 1) logprob must be
                     # present
-                    rank_one_appears = (True
-                                        if lp_rank == 1 else rank_one_appears)
+                    rank_one_appears = True if lp_rank == 1 else rank_one_appears
 
                     # Rank must be >= 1
-                    assert lp_rank >= 1, (f"Logprob {lp} has invalid"
-                                          f" rank {lp_rank} < 1."
-                                          f" Logprob dict: {pos_logprob_dict}")
+                    assert lp_rank >= 1, (
+                        f"Logprob {lp} has invalid"
+                        f" rank {lp_rank} < 1."
+                        f" Logprob dict: {pos_logprob_dict}"
+                    )
 
                     # Validate log probability
                     assert math.isclose(lp_val, ref_lp_val), (
                         f"Token id {ref_tok_id} appears in logprobs dict"
                         f" at position {idx} in completion with log"
                         f" probability {lp_val} but {ref_lp_val} was"
-                        f" expected. Logprob: {lp}")
+                        f" expected. Logprob: {lp}"
+                    )
 
-                assert rank_one_appears, (f"No Logprob has rank 1"
-                                          " in the following Logprob"
-                                          f" dict: {pos_logprob_dict}")
+                assert rank_one_appears, (
+                    f"No Logprob has rank 1"
+                    " in the following Logprob"
+                    f" dict: {pos_logprob_dict}"
+                )
 
                 # Validate logprobs detokenization
                 for lp_tok in pos_logprob_dict:
                     # Confirm that sample logprob decoded token matches
                     # the logprob token id at this sequence position
                     decoded_token = pos_logprob_dict[lp_tok].decoded_token
-                    ref_decoded_token = _ref_convert_id_to_token(
-                        dtv.tokenizer, lp_tok)
+                    ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, lp_tok)
                     assert decoded_token == ref_decoded_token, (
                         f"Sampled logprob token id {lp_tok} decodes to"
                         f" {ref_decoded_token} but Logprob decoded"
                         f" token is {decoded_token} instead"
-                        f" (at position {idx})")
+                        f" (at position {idx})"
+                    )
 
-                ref_cumulative_logprob += pos_logprob_dict[
-                    sampled_token].logprob
+                ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob
             # Assert that cumulative logprobs are correct
             assert math.isclose(cumulative_logprob, ref_cumulative_logprob)
         else:
@@ -254,7 +269,8 @@ def _validate_logprobs(
             assert prompt_logprobs is not None, (
                 f"Request {req_id} requires prompt"
                 " logprobs but prompt logprobs are"
-                " None.")
+                " None."
+            )
             # Require num prompt tokens to match num
             # prompt logprobs
             num_prompt_tokens = len(prompt_token_ids)
@@ -262,56 +278,70 @@ def _validate_logprobs(
             assert num_prompt_tokens == len_prompt_logprobs, (
                 f"Request {req_id} has {num_prompt_tokens}"
                 " prompt tokens but has"
-                f" {len_prompt_logprobs} prompt logprobs.")
+                f" {len_prompt_logprobs} prompt logprobs."
+            )
             # First prompt logprob is None
             first_plp_dict = prompt_logprobs[0]
             assert first_plp_dict is None, (
                 f"Request {req_id} first prompt logprob"
                 f" should be None but has following value"
-                f" instead: {first_plp_dict}")
+                f" instead: {first_plp_dict}"
+            )
             # Break out the reference prompt log prob value &
             # logprob token id matrices for the whole prompt.
             # Also break out the prompt token rank vector
-            (ref_prompt_logprob_toks, ref_prompt_logprob_vals,
-             ref_prompt_token_ranks) = ref_prompt_logprobs
+            (
+                ref_prompt_logprob_toks,
+                ref_prompt_logprob_vals,
+                ref_prompt_token_ranks,
+            ) = ref_prompt_logprobs
             for idx, (prompt_token, pos_logprob_dict) in enumerate(
-                    zip(prompt_token_ids[1:], prompt_logprobs[1:])):
-
+                zip(prompt_token_ids[1:], prompt_logprobs[1:])
+            ):
                 # Break out the reference prompt log prob value
                 # vector, prompt logprob token id vector, and
                 # prompt token rank at the current position.
-                (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals,
-                 ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :],
-                                               ref_prompt_logprob_vals[idx, :],
-                                               ref_prompt_token_ranks[idx])
+                (
+                    ref_pos_prompt_logprob_toks,
+                    ref_pos_prompt_logprob_vals,
+                    ref_pos_prompt_token_rank,
+                ) = (
+                    ref_prompt_logprob_toks[idx, :],
+                    ref_prompt_logprob_vals[idx, :],
+                    ref_prompt_token_ranks[idx],
+                )
 
                 # For each position in the prompt sequence,
                 # ensure the actual prompt token is among the
                 # logprobs
                 assert prompt_token in pos_logprob_dict, (
-                    f"Prompt token {prompt_token} not"
-                    f" present in logprob at index {idx}")
+                    f"Prompt token {prompt_token} not present in logprob at index {idx}"
+                )
                 # Validate number of prompt logprobs
                 num_plp_toks = len(pos_logprob_dict)
-                assert (num_plp_toks == num_prompt_logprobs
-                        or num_plp_toks == num_prompt_logprobs +
-                        1), ("Valid numbers of prompt logprobs are"
-                             f" {num_prompt_logprobs} or"
-                             f" {num_prompt_logprobs+1} but"
-                             f" {num_plp_toks} logprobs found at"
-                             f" position {idx}. Logprobs dict:"
-                             f" {pos_logprob_dict}")
+                assert (
+                    num_plp_toks == num_prompt_logprobs
+                    or num_plp_toks == num_prompt_logprobs + 1
+                ), (
+                    "Valid numbers of prompt logprobs are"
+                    f" {num_prompt_logprobs} or"
+                    f" {num_prompt_logprobs + 1} but"
+                    f" {num_plp_toks} logprobs found at"
+                    f" position {idx}. Logprobs dict:"
+                    f" {pos_logprob_dict}"
+                )
 
                 # Validate prompt token logprob rank
                 prmpt_tok_lp = pos_logprob_dict[prompt_token]
                 prmpt_tok_lp_rank = prmpt_tok_lp.rank
                 ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank
-                assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), (
+                assert ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank, (
                     "Prompt token logprob rank"
                     f" {prmpt_tok_lp_rank} does not match"
                     " correct value"
                     f" {ref_prmpt_tok_lp_rank}"
-                    f" in Logprob {prmpt_tok_lp}")
+                    f" in Logprob {prmpt_tok_lp}"
+                )
 
                 # Validate that the logprob processor yields
                 # the correct prompt log probs and valid
@@ -325,7 +355,8 @@ def _validate_logprobs(
                     ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx])
                     assert ref_tok_id in pos_logprob_dict, (
                         f"Expected token {ref_tok_id} to be"
-                        f" in logprob dict but it is not.")
+                        f" in logprob dict but it is not."
+                    )
 
                     # Extract actually-generated logprob
                     # info
@@ -335,87 +366,93 @@ def _validate_logprobs(
                     # A "top" (rank 1) logprob must be
                     # present
-                    rank_one_appears = (True
-                                        if plp_rank == 1 else rank_one_appears)
+                    rank_one_appears = True if plp_rank == 1 else rank_one_appears
 
                     # Rank must be >= 1
                     assert plp_rank >= 1, (
                         f"Logprob {plp} has invalid"
                         f" rank {plp_rank} < 1."
-                        f" Logprob dict: {pos_logprob_dict}")
+                        f" Logprob dict: {pos_logprob_dict}"
+                    )
 
                     # Validate log probability
                     assert math.isclose(plp_val, ref_plp_val), (
                         f"Token id {ref_tok_id} appears in logprobs dict"
                         f" at position {idx} in completion with log"
                         f" probability {plp_val} but {ref_plp_val} was"
-                        f" expected. Logprob: {plp}")
+                        f" expected. Logprob: {plp}"
+                    )
 
-                assert rank_one_appears, (f"No Logprob has rank 1"
-                                          " in the following Logprob"
-                                          f" dict: {pos_logprob_dict}")
+                assert rank_one_appears, (
+                    f"No Logprob has rank 1"
+                    " in the following Logprob"
+                    f" dict: {pos_logprob_dict}"
+                )
 
                 # Validate prompt logprob detokenization
                 for plp_tok in pos_logprob_dict:
                     # Confirm that prompt logprob decoded token matches
                     # the logprob token id at this sequence position
                     decoded_token = pos_logprob_dict[plp_tok].decoded_token
-                    ref_decoded_token = _ref_convert_id_to_token(
-                        dtv.tokenizer, plp_tok)
+                    ref_decoded_token = _ref_convert_id_to_token(dtv.tokenizer, plp_tok)
                     assert decoded_token == ref_decoded_token, (
                         f"Prompt logprob token id {plp_tok} decodes to"
                         f" {ref_decoded_token} but Logprob decoded"
                         f" token is {decoded_token} instead"
-                        f" (at position {idx})")
+                        f" (at position {idx})"
+                    )
         else:
             # Prompt logprobs disabled for this request
             assert prompt_logprobs is None
 
 
 @pytest.mark.parametrize(
-    "request_output_kind",
-    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
-@pytest.mark.parametrize("num_sample_logprobs",
-                         [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
-@pytest.mark.parametrize("num_prompt_logprobs",
-                         [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
-def test_logprobs_processor(request_output_kind: RequestOutputKind,
-                            num_sample_logprobs: Optional[int],
-                            num_prompt_logprobs: Optional[int],
-                            dummy_test_vectors):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
-                                       log_stats=False)
+    "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
+)
+@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
+@pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
+def test_logprobs_processor(
+    request_output_kind: RequestOutputKind,
+    num_sample_logprobs: Optional[int],
+    num_prompt_logprobs: Optional[int],
+    dummy_test_vectors,
+):
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
-        generated_logprobs_raw=None if num_sample_logprobs is None else
-        dummy_test_vectors.generation_logprobs,
+        generated_logprobs_raw=None
+        if num_sample_logprobs is None
+        else dummy_test_vectors.generation_logprobs,
         prompt_logprobs_raw=None
-        if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs)
+        if num_prompt_logprobs is None
+        else dummy_test_vectors.prompt_logprobs,
+    )
 
     # Make N requests.
     request_id_list = [
-        f"request-{idx}"
-        for idx in range(len(dummy_test_vectors.prompt_strings))
+        f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings))
     ]
     requests = [
-        EngineCoreRequest(request_id=request_id_list[idx],
-                          prompt_token_ids=prompt_tokens,
-                          mm_features=None,
-                          eos_token_id=None,
-                          arrival_time=0,
-                          lora_request=None,
-                          cache_salt=None,
-                          data_parallel_rank=None,
-                          sampling_params=SamplingParams(
-                              skip_special_tokens=False,
-                              spaces_between_special_tokens=False,
-                              output_kind=request_output_kind,
-                              stop=[],
-                              include_stop_str_in_output=False,
-                              logprobs=num_sample_logprobs,
-                              prompt_logprobs=num_prompt_logprobs,
-                          ),
-                          pooling_params=None)
+        EngineCoreRequest(
+            request_id=request_id_list[idx],
+            prompt_token_ids=prompt_tokens,
+            mm_features=None,
+            eos_token_id=None,
+            arrival_time=0,
+            lora_request=None,
+            cache_salt=None,
+            data_parallel_rank=None,
+            sampling_params=SamplingParams(
+                skip_special_tokens=False,
+                spaces_between_special_tokens=False,
+                output_kind=request_output_kind,
+                stop=[],
+                include_stop_str_in_output=False,
+                logprobs=num_sample_logprobs,
+                prompt_logprobs=num_prompt_logprobs,
+            ),
+            pooling_params=None,
+        )
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
@@ -446,7 +483,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
         prompt_logprobs = request_output.prompt_logprobs
         logprobs = request_output.outputs[0].logprobs
         gen_cumulative_logprobs[request_id] = request_output.outputs[
-            0].cumulative_logprob
+            0
+        ].cumulative_logprob
         if request_id not in gen_logprobs:
             # Start tracking sample and prompt logprobs for this request
             gen_tokens[request_id] = new_tokens
@@ -463,10 +501,16 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
             plp.extend(prompt_logprobs)
 
     # Confirmed tracked logprobs match what we expect
-    _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
-                       gen_cumulative_logprobs, dummy_test_vectors,
-                       request_id_list, num_sample_logprobs,
-                       num_prompt_logprobs)
+    _validate_logprobs(
+        gen_tokens,
+        gen_logprobs,
+        gen_prompt_logprobs,
+        gen_cumulative_logprobs,
+        dummy_test_vectors,
+        request_id_list,
+        num_sample_logprobs,
+        num_prompt_logprobs,
+    )
 
     assert output_processor.get_num_unfinished_requests() == 0
     assert not output_processor.has_unfinished_requests()
@@ -474,15 +518,23 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
 
 @pytest.mark.parametrize(
     "include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
-    [(False, "stop_token_ids", False, None),
-     (True, "stop_token_ids", False, None),
-     (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
-     (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
-     (False, "eos_token_id", False, None), (True, "eos_token_id", False, None),
-     (False, "eos_token_id", True, None)])
-def test_stop_token(include_stop_str_in_output: bool,
-                    num_sample_logprobs: Optional[int], stop_token_type: str,
-                    ignore_eos: bool, dummy_test_vectors):
+    [
+        (False, "stop_token_ids", False, None),
+        (True, "stop_token_ids", False, None),
+        (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+        (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
+        (False, "eos_token_id", False, None),
+        (True, "eos_token_id", False, None),
+        (False, "eos_token_id", True, None),
+    ],
+)
+def test_stop_token(
+    include_stop_str_in_output: bool,
+    num_sample_logprobs: Optional[int],
+    stop_token_type: str,
+    ignore_eos: bool,
+    dummy_test_vectors,
+):
     """Test output processor EOS/stop token handling.
 
     Send mock engine core request to mock engine core and pass core outputs
@@ -523,9 +575,10 @@ def test_stop_token(include_stop_str_in_output: bool,
         dummy_test_vectors: dummy engine core outputs and other data structures
     """
     model_id = dummy_test_vectors.tokenizer.name_or_path
-    if model_id != 'meta-llama/Llama-3.2-1B':
-        raise AssertionError("Test requires meta-llama/Llama-3.2-1B but "
-                             f"{model_id} is in use.")
+    if model_id != "meta-llama/Llama-3.2-1B":
+        raise AssertionError(
+            f"Test requires meta-llama/Llama-3.2-1B but {model_id} is in use."
+        )
     do_logprobs = num_sample_logprobs is not None
     # EOS under test; if False, stop_token_ids under test
     is_eos_test = stop_token_type == "eos_token_id"
@@ -536,18 +589,16 @@ def test_stop_token(include_stop_str_in_output: bool,
     )  # '<|end_of_text|>'
     stop_token_ids = [128009] if not is_eos_test else None  # '<|eot_id|>'
 
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
-                                       log_stats=False)
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
     # Dummy engine core outputs, with control tokens suffixed to test stops
-    suffix_token = ([eos_token_id] if is_eos_test else stop_token_ids)
+    suffix_token = [eos_token_id] if is_eos_test else stop_token_ids
    assert suffix_token is not None and isinstance(suffix_token[0], int)
     generation_string = dummy_test_vectors.generation_strings[0]
-    generation_tokens = (dummy_test_vectors.generation_tokens[0] +
-                         2 * suffix_token)
+    generation_tokens = dummy_test_vectors.generation_tokens[0] + 2 * suffix_token
     if do_logprobs:
-        generation_logprobs = (
-            dummy_test_vectors.generation_logprobs[0] +
-            2 * [dummy_test_vectors.generation_logprobs[0][-1]])
+        generation_logprobs = dummy_test_vectors.generation_logprobs[0] + 2 * [
+            dummy_test_vectors.generation_logprobs[0][-1]
+        ]
     prompt_string = dummy_test_vectors.prompt_strings[0]
     prompt_tokens = dummy_test_vectors.prompt_tokens[0]
     engine_core = MockEngineCore(
@@ -556,7 +607,8 @@ def test_stop_token(include_stop_str_in_output: bool,
         prompt_logprobs_raw=None,
         eos_token_id=eos_token_id,
         stop_token_ids=stop_token_ids,
-        ignore_eos=ignore_eos)
+        ignore_eos=ignore_eos,
+    )
 
     # Make request.
     request_id = "request-0"
@@ -580,7 +632,8 @@ def test_stop_token(include_stop_str_in_output: bool,
             prompt_logprobs=None,
             ignore_eos=ignore_eos,
         ),
-        pooling_params=None)
+        pooling_params=None,
+    )
 
     # Add request to the detokenizer.
     output_processor.add_request(request, prompt_string)
@@ -605,7 +658,7 @@ def test_stop_token(include_stop_str_in_output: bool,
         # Update tracking.
         request_output = request_outputs[0]
         if request_output.finished:
-            finish_reason = ("length" if is_eos_ignore_test else "stop")
+            finish_reason = "length" if is_eos_ignore_test else "stop"
             assert request_output.outputs[0].finish_reason == finish_reason
 
         gen_string += request_output.outputs[0].text
@@ -614,7 +667,7 @@ def test_stop_token(include_stop_str_in_output: bool,
             gen_logprobs.extend(request_output.outputs[0].logprobs)
 
     # Validate generated text
-    control_token = '<|end_of_text|>' if is_eos_test else '<|eot_id|>'
+    control_token = "<|end_of_text|>" if is_eos_test else "<|eot_id|>"
     if is_eos_ignore_test:
         # Length-based stop; expect full string
         ref_str = generation_string + 2 * control_token
@@ -624,14 +677,15 @@ def test_stop_token(include_stop_str_in_output: bool,
     else:
         # Stop token triggered but not in output
         ref_str = generation_string
-    assert gen_string == ref_str, (f"{gen_string=}, {ref_str=}")
+    assert gen_string == ref_str, f"{gen_string=}, {ref_str=}"
 
     if do_logprobs:
         # Validate number of sample logprobs
         num_tokens = len(gen_tokens)
         num_logprobs = len(gen_logprobs)
         assert num_tokens == num_logprobs, (
-            f"Token count ({num_tokens}) != logprobs count ({num_logprobs})")
+            f"Token count ({num_tokens}) != logprobs count ({num_logprobs})"
+        )
 
     # Check requests are finished
     assert output_processor.get_num_unfinished_requests() == 0
@@ -639,22 +693,24 @@ def test_stop_token(include_stop_str_in_output: bool,
 
 
 @pytest.mark.parametrize("include_stop_str_in_output", [True, False])
-@pytest.mark.parametrize("num_sample_logprobs",
-                         [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
-def test_stop_string(include_stop_str_in_output: bool,
-                     num_sample_logprobs: Optional[int], dummy_test_vectors):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
-                                       log_stats=False)
+@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
+def test_stop_string(
+    include_stop_str_in_output: bool,
+    num_sample_logprobs: Optional[int],
+    dummy_test_vectors,
+):
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
         generated_logprobs_raw=dummy_test_vectors.generation_logprobs
-        if num_sample_logprobs else None,
-        prompt_logprobs_raw=None)
+        if num_sample_logprobs
+        else None,
+        prompt_logprobs_raw=None,
+    )
 
     # Make N requests.
     request_id_list = [
-        f"request-{idx}"
-        for idx in range(len(dummy_test_vectors.prompt_strings))
+        f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings))
     ]
     requests = [
         EngineCoreRequest(
@@ -675,7 +731,8 @@ def test_stop_string(include_stop_str_in_output: bool,
                 logprobs=num_sample_logprobs,
                 prompt_logprobs=None,
             ),
-            pooling_params=None)
+            pooling_params=None,
+        )
         for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
@@ -715,7 +772,8 @@ def test_stop_string(include_stop_str_in_output: bool,
         prompt_logprobs = request_output.prompt_logprobs
         logprobs = request_output.outputs[0].logprobs
         gen_cumulative_logprobs[request_id] = request_output.outputs[
-            0].cumulative_logprob
+            0
+        ].cumulative_logprob
         if request_id not in gen_strings:
             gen_strings[request_id] = new_text
             gen_tokens[request_id] = new_tokens
@@ -733,8 +791,8 @@ def test_stop_string(include_stop_str_in_output: bool,
 
     # Confirmed tracked values matches what we expected.
     for idx, (ref_gen_str, stop_str) in enumerate(
-            zip(dummy_test_vectors.generation_strings, STOP_STRINGS)):
-
+        zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
+    ):
         # Request should be aborted.
         request_id = f"request-{idx}"
         assert request_id in aborted
@@ -748,24 +806,28 @@ def test_stop_string(include_stop_str_in_output: bool,
         ref_str_inc_stop = ref_gen_str[:stop_str_idx] + stop_str
 
         if include_stop_str_in_output:
-            assert gen_str == ref_str_inc_stop, (
-                f"{gen_str=}, {ref_str_inc_stop=}")
+            assert gen_str == ref_str_inc_stop, f"{gen_str=}, {ref_str_inc_stop=}"
         else:
-            assert gen_str == ref_str_exc_stop, (
-                f"{gen_str=}, {ref_str_exc_stop=}")
+            assert gen_str == ref_str_exc_stop, f"{gen_str=}, {ref_str_exc_stop=}"
 
     # Confirmed tracked logprobs match what we expect
-    _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
-                       gen_cumulative_logprobs, dummy_test_vectors,
-                       request_id_list, num_sample_logprobs, None)
+    _validate_logprobs(
+        gen_tokens,
+        gen_logprobs,
+        gen_prompt_logprobs,
+        gen_cumulative_logprobs,
+        dummy_test_vectors,
+        request_id_list,
+        num_sample_logprobs,
+        None,
+    )
 
     assert output_processor.get_num_unfinished_requests() == 0
     assert not output_processor.has_unfinished_requests()
 
 
 def test_iteration_stats(dummy_test_vectors):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
-                                       log_stats=True)
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
     engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
     engine_core_timestamp = time.monotonic()
 
@@ -782,7 +844,8 @@ def test_iteration_stats(dummy_test_vectors):
             data_parallel_rank=None,
             sampling_params=SamplingParams(),
             pooling_params=None,
-        ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
+        )
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
     # Add all requests except one to the OutputProcessor.
@@ -794,12 +857,13 @@ def test_iteration_stats(dummy_test_vectors):
     # First iteration has 2 prefills.
     outputs = engine_core.get_outputs()[:num_active]
     iteration_stats = IterationStats()
-    output_processor.process_outputs(outputs, engine_core_timestamp,
-                                     iteration_stats)
-    total_prompt_tokens = sum([
-        len(prompt_tokens)
-        for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
-    ])
+    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
+    total_prompt_tokens = sum(
+        [
+            len(prompt_tokens)
+            for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
+        ]
+    )
 
     assert iteration_stats.num_prompt_tokens == total_prompt_tokens
     assert iteration_stats.num_generation_tokens == num_active
@@ -807,8 +871,7 @@ def test_iteration_stats(dummy_test_vectors):
     # Just decodes in this step.
     outputs = engine_core.get_outputs()[:num_active]
     iteration_stats = IterationStats()
-    output_processor.process_outputs(outputs, engine_core_timestamp,
-                                     iteration_stats)
+    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
 
     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
@@ -818,8 +881,7 @@ def test_iteration_stats(dummy_test_vectors):
     num_active += 1
     outputs = engine_core.get_outputs()[:num_active]
     iteration_stats = IterationStats()
-    output_processor.process_outputs(outputs, engine_core_timestamp,
-                                     iteration_stats)
+    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
     total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
 
     assert iteration_stats.num_prompt_tokens == total_prompt_tokens
@@ -828,8 +890,7 @@ def test_iteration_stats(dummy_test_vectors):
     # Just decodes in this step.
     outputs = engine_core.get_outputs()[:num_active]
     iteration_stats = IterationStats()
-    output_processor.process_outputs(outputs, engine_core_timestamp,
-                                     iteration_stats)
+    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
 
     assert iteration_stats.num_prompt_tokens == 0
     assert iteration_stats.num_generation_tokens == num_active
@@ -853,16 +914,13 @@ async def test_request_output_collector():
                     text=TEXT,
                     token_ids=[idx],
                     cumulative_logprob=(idx + 1 * 1.0),
-                    logprobs=[{
-                        "a": idx,
-                        "b": idx
-                    }],
-                    finish_reason="length" if
-                    (idx == NUM_REQS - 1) else None,
+                    logprobs=[{"a": idx, "b": idx}],
+                    finish_reason="length" if (idx == NUM_REQS - 1) else None,
                 )
             ],
             finished=(idx == NUM_REQS - 1),
-        ) for idx in range(NUM_REQS)
+        )
+        for idx in range(NUM_REQS)
     ]
 
     collector = RequestOutputCollector(RequestOutputKind.DELTA)
@@ -888,8 +946,7 @@ async def test_request_output_collector():
     assert not output.finished
     # Text, token_ids, and logprobs should get merged.
     assert output.outputs[0].text == TEXT * num_to_put
-    for tok_0, tok_1 in zip(output.outputs[0].token_ids,
-                            list(range(num_to_put))):
+    for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))):
         assert tok_0 == tok_1
     assert len(output.outputs[0].logprobs) == num_to_put
 
@@ -910,8 +967,7 @@ async def test_request_output_collector():
     assert output.outputs[0].finish_reason == "length"
     # Text, token_ids, and logprobs should get merged.
     assert output.outputs[0].text == TEXT * num_to_put
-    for tok_0, tok_1 in zip(output.outputs[0].token_ids,
-                            list(range(num_to_put))):
+    for tok_0, tok_1 in zip(output.outputs[0].token_ids, list(range(num_to_put))):
         assert tok_0 == tok_1
     assert len(output.outputs[0].logprobs) == num_to_put
 
@@ -1003,8 +1059,7 @@ async def test_cumulative_output_collector_n():
 
 @pytest.mark.parametrize("runner", ["generate", "pooling"])
 def test_abort_requests(runner: str, dummy_test_vectors):
-    output_processor = OutputProcessor(dummy_test_vectors.tokenizer,
-                                       log_stats=True)
+    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
     requests = [
         EngineCoreRequest(
             request_id=f"request-{idx}",
@@ -1016,9 +1071,9 @@ def test_abort_requests(runner: str, dummy_test_vectors):
             cache_salt=None,
             data_parallel_rank=None,
             sampling_params=SamplingParams() if runner == "generate" else None,
-            pooling_params=PoolingParams(
-                task="embed") if runner == "pooling" else None,
-        ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
+            pooling_params=PoolingParams(task="embed") if runner == "pooling" else None,
+        )
+        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
     ]
 
     for request in requests: