[mypy] Enable following imports for entrypoints (#7248)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Fei <dfdfcai4@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# imports for guided decoding tests
|
||||
import json
|
||||
import re
|
||||
from typing import List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, prompt_logprobs",
|
||||
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
|
||||
)
|
||||
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
prompt_logprobs: Optional[int]):
|
||||
params: Dict = {
|
||||
"messages": [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}],
|
||||
"model":
|
||||
model_name
|
||||
}
|
||||
|
||||
if prompt_logprobs is not None:
|
||||
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
|
||||
|
||||
if prompt_logprobs is not None and prompt_logprobs < 0:
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.chat.completions.create(**params)
|
||||
else:
|
||||
completion = await client.chat.completions.create(**params)
|
||||
if prompt_logprobs is not None:
|
||||
assert completion.prompt_logprobs is not None
|
||||
assert len(completion.prompt_logprobs) > 0
|
||||
else:
|
||||
assert completion.prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
params: Dict = {
|
||||
"messages": [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}],
|
||||
"model":
|
||||
model_name,
|
||||
"extra_body": {
|
||||
"prompt_logprobs": 1
|
||||
}
|
||||
}
|
||||
|
||||
completion_1 = await client.chat.completions.create(**params)
|
||||
|
||||
params["extra_body"] = {"prompt_logprobs": 2}
|
||||
completion_2 = await client.chat.completions.create(**params)
|
||||
|
||||
assert len(completion_1.prompt_logprobs[3]) == 1
|
||||
assert len(completion_2.prompt_logprobs[3]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
|
||||
Reference in New Issue
Block a user