[Core] *Prompt* logprobs support in Multi-step (#8199)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from vllm.sequence import Logprob, SampleLogprobs
|
||||
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
|
||||
|
||||
TokensText = Tuple[List[int], str]
|
||||
|
||||
@@ -34,20 +34,47 @@ def check_outputs_equal(
|
||||
assert output_ids_0 == output_ids_1, fail_msg
|
||||
|
||||
|
||||
# Representation of generated sequence as a tuple of
|
||||
# * Token ID list
|
||||
# * String
|
||||
# * List of top sample logprobs for each sampled token
|
||||
#
|
||||
# Assumes prompt logprobs were not requested.
|
||||
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
|
||||
float]],
|
||||
SampleLogprobs]]]
|
||||
|
||||
# Allow for tokens to be represented as str's rather than IDs
|
||||
# Allow for tokens to be represented as str's rather than IDs;
|
||||
# tuple of
|
||||
# * Token string representations list
|
||||
# * String
|
||||
# * Optional list of top sample logprobs for each sampled token
|
||||
#
|
||||
# Assumes prompt logprobs were not requested.
|
||||
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
|
||||
List[Dict[str,
|
||||
Logprob]]]]]
|
||||
|
||||
# Representation of generated sequence as a tuple of
|
||||
# * Token ID list
|
||||
# * String
|
||||
# * Optional list of top sample logprobs for each sampled token
|
||||
# * Optional list of top prompt logprobs for each prompt token
|
||||
#
|
||||
# Allows prompt logprobs to be requested.
|
||||
TokensTextLogprobsPromptLogprobs = Tuple[
|
||||
List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
|
||||
Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
|
||||
|
||||
|
||||
def check_logprobs_close(
|
||||
*,
|
||||
outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
|
||||
outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
|
||||
outputs_0_lst: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
TextTextLogprobs]],
|
||||
outputs_1_lst: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
TextTextLogprobs]],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
num_outputs_0_skip_tokens: int = 0,
|
||||
@@ -57,6 +84,18 @@ def check_logprobs_close(
|
||||
"""Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
|
||||
How sample logprobs are compared:
|
||||
* `always_check_logprobs == True`: set of highest-logprob token ids
|
||||
must match between seq0 and seq1 at all sampled token offsets
|
||||
* `always_check_logprobs == False`: highest-logprob token ids are
|
||||
only compared at sampled token offsets for which generated token
|
||||
ids don't match
|
||||
|
||||
Prompt logprobs must be provided either for both input sequences, or
|
||||
for neither. If prompt logprobs are provided, then highest-logprob
|
||||
prompt token ids must match between seq0 and seq1 at all prompt token
|
||||
offsets.
|
||||
|
||||
Args:
|
||||
outputs_0_lst: First sequence to compare
|
||||
outputs_0_lst: Second sequence to compare
|
||||
@@ -78,8 +117,65 @@ def check_logprobs_close(
|
||||
for prompt_idx, (outputs_0,
|
||||
outputs_1) in enumerate(zip(outputs_0_lst,
|
||||
outputs_1_lst)):
|
||||
output_ids_0, output_str_0, logprobs_0 = outputs_0
|
||||
output_ids_1, output_str_1, logprobs_1 = outputs_1
|
||||
assert len(outputs_0) == len(outputs_1)
|
||||
if len(outputs_0) == 3:
|
||||
assert len(outputs_1) == 3
|
||||
# Break out tokens, text & sample logprobs
|
||||
# (prompt logprobs were not provided)
|
||||
output_ids_0, output_str_0, logprobs_0 = outputs_0
|
||||
output_ids_1, output_str_1, logprobs_1 = outputs_1
|
||||
elif len(outputs_0) == 4:
|
||||
assert len(outputs_1) == 4
|
||||
# Break out tokens, text, sample logprobs & prompt logprobs
|
||||
(
|
||||
output_ids_0,
|
||||
output_str_0,
|
||||
logprobs_0,
|
||||
prompt_logprobs_0,
|
||||
) = outputs_0
|
||||
(
|
||||
output_ids_1,
|
||||
output_str_1,
|
||||
logprobs_1,
|
||||
prompt_logprobs_1,
|
||||
) = outputs_1
|
||||
|
||||
# Test prompt logprobs closeness
|
||||
if (prompt_logprobs_0 is not None
|
||||
and prompt_logprobs_1 is not None):
|
||||
# Both sequences' prompt logprobs lists are not `None``
|
||||
# (although individual list elements may be `None`);
|
||||
# for each token's logprobs:
|
||||
for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
|
||||
zip(prompt_logprobs_0, prompt_logprobs_1)):
|
||||
fail_msg = (
|
||||
f"Prompt logprobs test:"
|
||||
f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
|
||||
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
|
||||
|
||||
if logprobs_elem_0 is None:
|
||||
# If the seq 0 token's logprobs are `None`,
|
||||
# the seq 1 token's logprobs must be `None`
|
||||
assert logprobs_elem_1 is None, fail_msg
|
||||
else:
|
||||
# If the seq 0 token's logprobs are not `None`,
|
||||
# the seq 1 token's logprobs must not be `None`
|
||||
assert logprobs_elem_1 is not None, fail_msg
|
||||
# Logprobs check: top-k token choices must be the same
|
||||
assert (set(logprobs_elem_0.keys()) == set(
|
||||
logprobs_elem_1.keys())), fail_msg
|
||||
else:
|
||||
# Both sequence logprobs lists must be `None`
|
||||
fail_msg = (f"Prompt logprobs test:"
|
||||
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
|
||||
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
|
||||
|
||||
assert (prompt_logprobs_0 is None
|
||||
and prompt_logprobs_1 is None), fail_msg
|
||||
else:
|
||||
raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
|
||||
f"{len(outputs_0)} elements were provided: "
|
||||
f"{outputs_0}")
|
||||
|
||||
if logprobs_0 is None:
|
||||
logprobs_0 = [None] * len(output_ids_0)
|
||||
|
||||
Reference in New Issue
Block a user