[mypy] Enable following imports for entrypoints (#7248)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Fei <dfdfcai4@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@ import json
|
||||
import re
|
||||
import shutil
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
@@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
|
||||
assert len(completion.choices[0].text) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, prompt_logprobs",
|
||||
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
|
||||
)
|
||||
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
||||
model_name: str, prompt_logprobs: int):
|
||||
params: Dict = {
|
||||
"messages": [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}],
|
||||
"model":
|
||||
model_name
|
||||
}
|
||||
|
||||
if prompt_logprobs is not None:
|
||||
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
|
||||
|
||||
if prompt_logprobs and prompt_logprobs < 0:
|
||||
with pytest.raises(BadRequestError) as err_info:
|
||||
await client.chat.completions.create(**params)
|
||||
expected_err_string = (
|
||||
"Error code: 400 - {'object': 'error', 'message': "
|
||||
"'Prompt_logprobs set to invalid negative value: -1',"
|
||||
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
|
||||
assert str(err_info.value) == expected_err_string
|
||||
else:
|
||||
completion = await client.chat.completions.create(**params)
|
||||
if prompt_logprobs and prompt_logprobs > 0:
|
||||
assert completion.prompt_logprobs is not None
|
||||
assert len(completion.prompt_logprobs) > 0
|
||||
else:
|
||||
assert completion.prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
params: Dict = {
|
||||
"messages": [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}],
|
||||
"model":
|
||||
model_name,
|
||||
"extra_body": {
|
||||
"prompt_logprobs": 1
|
||||
}
|
||||
}
|
||||
|
||||
completion_1 = await client.chat.completions.create(**params)
|
||||
|
||||
params["extra_body"] = {"prompt_logprobs": 2}
|
||||
completion_2 = await client.chat.completions.create(**params)
|
||||
|
||||
assert len(completion_1.prompt_logprobs[3]) == 1
|
||||
assert len(completion_2.prompt_logprobs[3]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
|
||||
(MODEL_NAME, 0),
|
||||
@@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
||||
(MODEL_NAME, None)])
|
||||
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
prompt_logprobs: int):
|
||||
prompt_logprobs: Optional[int]):
|
||||
params: Dict = {
|
||||
"prompt": ["A robot may not injure another robot", "My name is"],
|
||||
"model": model_name,
|
||||
@@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
||||
if prompt_logprobs is not None:
|
||||
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
|
||||
|
||||
if prompt_logprobs and prompt_logprobs < 0:
|
||||
with pytest.raises(BadRequestError) as err_info:
|
||||
if prompt_logprobs is not None and prompt_logprobs < 0:
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(**params)
|
||||
expected_err_string = (
|
||||
"Error code: 400 - {'object': 'error', 'message': "
|
||||
"'Prompt_logprobs set to invalid negative value: -1',"
|
||||
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
|
||||
assert str(err_info.value) == expected_err_string
|
||||
else:
|
||||
completion = await client.completions.create(**params)
|
||||
if prompt_logprobs and prompt_logprobs > 0:
|
||||
if prompt_logprobs is not None:
|
||||
assert completion.choices[0].prompt_logprobs is not None
|
||||
assert len(completion.choices[0].prompt_logprobs) > 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user