[CI/Build] Reuse code for checking output consistency (#5988)
This commit is contained in:
@@ -1,7 +1,43 @@
|
||||
def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
|
||||
"""Compare the logprobs of two sequences generated by different models,
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
TokensText = Tuple[List[int], str]
|
||||
|
||||
|
||||
def check_outputs_equal(outputs_0_lst: List[TokensText],
|
||||
outputs_1_lst: List[TokensText], name_0: str,
|
||||
name_1: str):
|
||||
"""
|
||||
Compare the two sequences generated by different models,
|
||||
which should be equal.
|
||||
"""
|
||||
assert len(outputs_0_lst) == len(outputs_1_lst)
|
||||
|
||||
for prompt_idx, (outputs_0,
|
||||
outputs_1) in enumerate(zip(outputs_0_lst,
|
||||
outputs_1_lst)):
|
||||
output_ids_0, output_str_0 = outputs_0
|
||||
output_ids_1, output_str_1 = outputs_1
|
||||
|
||||
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
|
||||
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
|
||||
|
||||
|
||||
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
|
||||
outputs_1_lst: List[TokensTextLogprobs], name_0: str,
|
||||
name_1: str):
|
||||
"""
|
||||
Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
"""
|
||||
assert len(outputs_0_lst) == len(outputs_1_lst)
|
||||
|
||||
# Loop through responses to each prompt.
|
||||
for prompt_idx, (outputs_0,
|
||||
outputs_1) in enumerate(zip(outputs_0_lst,
|
||||
|
||||
Reference in New Issue
Block a user