[Core] Dynamic image size support for VLMs (#5276)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com> Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com> Co-authored-by: ywang96 <ywang@roblox.com> Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com> Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
@@ -1,11 +1,18 @@
|
||||
from typing import Dict, List, Tuple
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
TokensText = Tuple[List[int], str]
|
||||
|
||||
|
||||
def check_outputs_equal(outputs_0_lst: List[TokensText],
|
||||
outputs_1_lst: List[TokensText], name_0: str,
|
||||
name_1: str):
|
||||
def check_outputs_equal(
|
||||
*,
|
||||
outputs_0_lst: Sequence[TokensText],
|
||||
outputs_1_lst: Sequence[TokensText],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
):
|
||||
"""
|
||||
Compare the two sequences generated by different models,
|
||||
which should be equal.
|
||||
@@ -18,20 +25,28 @@ def check_outputs_equal(outputs_0_lst: List[TokensText],
|
||||
output_ids_0, output_str_0 = outputs_0
|
||||
output_ids_1, output_str_1 = outputs_1
|
||||
|
||||
assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
# The text and token outputs should exactly match
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
assert output_str_0 == output_str_1, fail_msg
|
||||
assert output_ids_0 == output_ids_1, fail_msg
|
||||
|
||||
|
||||
TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
|
||||
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
|
||||
float]],
|
||||
SampleLogprobs]]]
|
||||
|
||||
|
||||
def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
|
||||
outputs_1_lst: List[TokensTextLogprobs], name_0: str,
|
||||
name_1: str):
|
||||
def check_logprobs_close(
|
||||
*,
|
||||
outputs_0_lst: Sequence[TokensTextLogprobs],
|
||||
outputs_1_lst: Sequence[TokensTextLogprobs],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
warn_on_mismatch: bool = True,
|
||||
):
|
||||
"""
|
||||
Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
@@ -45,21 +60,52 @@ def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
|
||||
output_ids_0, output_str_0, logprobs_0 = outputs_0
|
||||
output_ids_1, output_str_1, logprobs_1 = outputs_1
|
||||
|
||||
if logprobs_0 is None:
|
||||
logprobs_0 = [None] * len(output_ids_0)
|
||||
if logprobs_1 is None:
|
||||
logprobs_1 = [None] * len(output_ids_1)
|
||||
|
||||
# Loop through generated tokens.
|
||||
for idx, (output_id_0,
|
||||
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
||||
|
||||
# If generated tokens don't match, then
|
||||
if output_id_0 != output_id_1:
|
||||
logprobs_elem_0 = logprobs_0[idx]
|
||||
logprobs_elem_1 = logprobs_1[idx]
|
||||
|
||||
# Each predicted token must be in top N logprobs of the other
|
||||
assert output_id_0 in logprobs_1[idx], (
|
||||
fail_msg = (
|
||||
f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
assert output_id_1 in logprobs_0[idx], (
|
||||
f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
|
||||
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
|
||||
|
||||
assert logprobs_elem_0 is not None, fail_msg
|
||||
assert logprobs_elem_1 is not None, fail_msg
|
||||
assert output_id_0 in logprobs_elem_1, fail_msg
|
||||
assert output_id_1 in logprobs_elem_0, fail_msg
|
||||
|
||||
if warn_on_mismatch:
|
||||
with warnings.catch_warnings():
|
||||
# This ensures that repeated warnings are shown
|
||||
# in the output, not just the first occurrence
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
||||
# Break out since sequences will now diverge.
|
||||
break
|
||||
else:
|
||||
if output_str_0 != output_str_1 and warn_on_mismatch:
|
||||
# The token outputs exactly match,
|
||||
# so the text outputs should exactly match as well
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
# This ensures that repeated warnings are shown
|
||||
# in the output, not just the first occurrence
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
||||
Reference in New Issue
Block a user