Improve parse_raw_prompt test cases for invalid input .v2 (#30512)
Signed-off-by: Kayvan Mivehnejad <K.Mivehnejad@gmail.com>
This commit is contained in:
committed by
GitHub
parent
dc7fb5bebe
commit
29f7d97715
@@ -34,6 +34,13 @@ INPUTS_SLICES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# The parametrized input pairs a list of ints with a list of strings;
# parse_raw_prompts is expected to reject that combination with a TypeError.
@pytest.mark.parametrize(
    "invalid_input",
    [
        [[1, 2], ["foo", "bar"]],
    ],
)
def test_invalid_input_raise_type_error(invalid_input):
    """Check that a nested list mixing ints and strings raises TypeError."""
    with pytest.raises(TypeError):
        parse_raw_prompts(invalid_input)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_raw_single_batch_empty():
|
def test_parse_raw_single_batch_empty():
|
||||||
with pytest.raises(ValueError, match="at least one prompt"):
|
with pytest.raises(ValueError, match="at least one prompt"):
|
||||||
parse_raw_prompts([])
|
parse_raw_prompts([])
|
||||||
|
|||||||
@@ -33,22 +33,31 @@ def parse_raw_prompts(
|
|||||||
if len(prompt) == 0:
|
if len(prompt) == 0:
|
||||||
raise ValueError("please provide at least one prompt")
|
raise ValueError("please provide at least one prompt")
|
||||||
|
|
||||||
|
# case 2: array of strings
|
||||||
if is_list_of(prompt, str):
|
if is_list_of(prompt, str):
|
||||||
# case 2: array of strings
|
|
||||||
prompt = cast(list[str], prompt)
|
prompt = cast(list[str], prompt)
|
||||||
return [TextPrompt(prompt=elem) for elem in prompt]
|
return [TextPrompt(prompt=elem) for elem in prompt]
|
||||||
|
|
||||||
|
# case 3: array of tokens
|
||||||
if is_list_of(prompt, int):
|
if is_list_of(prompt, int):
|
||||||
# case 3: array of tokens
|
|
||||||
prompt = cast(list[int], prompt)
|
prompt = cast(list[int], prompt)
|
||||||
return [TokensPrompt(prompt_token_ids=prompt)]
|
return [TokensPrompt(prompt_token_ids=prompt)]
|
||||||
if is_list_of(prompt, list):
|
|
||||||
prompt = cast(list[list[int]], prompt)
|
|
||||||
if len(prompt[0]) == 0:
|
|
||||||
raise ValueError("please provide at least one prompt")
|
|
||||||
|
|
||||||
if is_list_of(prompt[0], int):
|
# case 4: array of token arrays
|
||||||
# case 4: array of token arrays
|
if is_list_of(prompt, list):
|
||||||
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
|
first = prompt[0]
|
||||||
|
if not isinstance(first, list):
|
||||||
|
raise ValueError("prompt expected to be a list of lists")
|
||||||
|
|
||||||
|
if len(first) == 0:
|
||||||
|
raise ValueError("Please provide at least one prompt")
|
||||||
|
|
||||||
|
# strict validation: every nested list must be list[int]
|
||||||
|
if not all(is_list_of(elem, int) for elem in prompt):
|
||||||
|
raise TypeError("Nested lists must contain only integers")
|
||||||
|
|
||||||
|
prompt = cast(list[list[int]], prompt)
|
||||||
|
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
|
||||||
|
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"prompt must be a string, array of strings, "
|
"prompt must be a string, array of strings, "
|
||||||
|
|||||||
Reference in New Issue
Block a user