Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -50,16 +50,13 @@ async def client(server):
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_single_completion(client: openai.AsyncOpenAI,
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str) -> None:
|
||||
|
||||
async def test_single_completion(
|
||||
client: openai.AsyncOpenAI, server: RemoteOpenAIServer, model_name: str
|
||||
) -> None:
|
||||
async def make_request():
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=10,
|
||||
temperature=1.0)
|
||||
model=model_name, prompt="Hello, my name is", max_tokens=10, temperature=1.0
|
||||
)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
@@ -108,9 +105,9 @@ async def test_single_completion(client: openai.AsyncOpenAI,
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str) -> None:
|
||||
async def test_completion_streaming(
|
||||
client: openai.AsyncOpenAI, server: RemoteOpenAIServer, model_name: str
|
||||
) -> None:
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
async def make_streaming_request():
|
||||
@@ -124,11 +121,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
single_output = single_completion.choices[0].text
|
||||
|
||||
# Perform the streaming request
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True)
|
||||
stream = await client.completions.create(
|
||||
model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, stream=True
|
||||
)
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
last_chunk = None
|
||||
@@ -139,16 +134,15 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
last_chunk = chunk # Keep track of the last chunk
|
||||
|
||||
# finish reason should only return in the last block for OpenAI API
|
||||
assert finish_reason_count == 1, (
|
||||
"Finish reason should appear exactly once.")
|
||||
assert last_chunk is not None, (
|
||||
"Stream should have yielded at least one chunk.")
|
||||
assert last_chunk.choices[
|
||||
0].finish_reason == "length", "Finish reason should be 'length'."
|
||||
assert finish_reason_count == 1, "Finish reason should appear exactly once."
|
||||
assert last_chunk is not None, "Stream should have yielded at least one chunk."
|
||||
assert last_chunk.choices[0].finish_reason == "length", (
|
||||
"Finish reason should be 'length'."
|
||||
)
|
||||
# Check that the combined text matches the non-streamed version.
|
||||
assert "".join(
|
||||
chunks
|
||||
) == single_output, "Streamed output should match non-streamed output."
|
||||
assert "".join(chunks) == single_output, (
|
||||
"Streamed output should match non-streamed output."
|
||||
)
|
||||
return True # Indicate success for this request
|
||||
|
||||
# Test single request
|
||||
@@ -162,9 +156,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
tasks = [make_streaming_request() for _ in range(num_requests)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(
|
||||
results
|
||||
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
|
||||
assert len(results) == num_requests, (
|
||||
f"Expected {num_requests} results, got {len(results)}"
|
||||
)
|
||||
assert all(results), "Not all streaming requests completed successfully."
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
@@ -172,9 +166,9 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
tasks = [make_streaming_request() for _ in range(num_requests)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(
|
||||
results
|
||||
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
|
||||
assert len(results) == num_requests, (
|
||||
f"Expected {num_requests} results, got {len(results)}"
|
||||
)
|
||||
assert all(results), "Not all streaming requests completed successfully."
|
||||
|
||||
# Check request balancing via Prometheus metrics if DP_SIZE > 1
|
||||
|
||||
Reference in New Issue
Block a user