Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -76,13 +76,15 @@ def get_answer_value(answer_str: str) -> int:
|
||||
return INVALID
|
||||
|
||||
|
||||
async def call_vllm_api(session: aiohttp.ClientSession,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: Optional[list[str]] = None,
|
||||
url: Optional[str] = None,
|
||||
seed: Optional[int] = None) -> str:
|
||||
async def call_vllm_api(
|
||||
session: aiohttp.ClientSession,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: Optional[list[str]] = None,
|
||||
url: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Call vLLM's OpenAI-compatible completions endpoint."""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
@@ -94,8 +96,7 @@ async def call_vllm_api(session: aiohttp.ClientSession,
|
||||
data["seed"] = seed
|
||||
|
||||
try:
|
||||
async with session.post(f"{url}/v1/completions",
|
||||
json=data) as response:
|
||||
async with session.post(f"{url}/v1/completions", json=data) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
return result["choices"][0]["text"]
|
||||
@@ -104,16 +105,18 @@ async def call_vllm_api(session: aiohttp.ClientSession,
|
||||
return ""
|
||||
|
||||
|
||||
def evaluate_gsm8k(num_questions: int = 1319,
|
||||
num_shots: int = 5,
|
||||
max_tokens: int = 256,
|
||||
host: str = "http://127.0.0.1",
|
||||
port: int = 8000,
|
||||
temperature: float = 0.0,
|
||||
seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
|
||||
def evaluate_gsm8k(
|
||||
num_questions: int = 1319,
|
||||
num_shots: int = 5,
|
||||
max_tokens: int = 256,
|
||||
host: str = "http://127.0.0.1",
|
||||
port: int = 8000,
|
||||
temperature: float = 0.0,
|
||||
seed: Optional[int] = 42,
|
||||
) -> dict[str, Union[float, int]]:
|
||||
"""
|
||||
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
||||
|
||||
|
||||
Returns dict with accuracy, invalid_rate, latency, etc.
|
||||
"""
|
||||
base_url = f"{host}:{port}"
|
||||
@@ -127,8 +130,10 @@ def evaluate_gsm8k(num_questions: int = 1319,
|
||||
# Build few-shot examples from train split (like lm-eval does)
|
||||
few_shot_examples = ""
|
||||
for i in range(num_shots):
|
||||
few_shot_examples += (f"Question: {train_data[i]['question']}\n"
|
||||
f"Answer: {train_data[i]['answer']}\n\n")
|
||||
few_shot_examples += (
|
||||
f"Question: {train_data[i]['question']}\n"
|
||||
f"Answer: {train_data[i]['answer']}\n\n"
|
||||
)
|
||||
|
||||
# Prepare test questions and labels from test split
|
||||
questions = []
|
||||
@@ -157,15 +162,15 @@ def evaluate_gsm8k(num_questions: int = 1319,
|
||||
states[i] = answer
|
||||
return answer
|
||||
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
|
||||
total=600)) as session:
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=600)
|
||||
) as session:
|
||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||
await tqdm.gather(*tasks, desc="Evaluating")
|
||||
|
||||
return states
|
||||
|
||||
print(f"Running GSM8K evaluation: {num_questions} questions, "
|
||||
f"{num_shots}-shot")
|
||||
print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")
|
||||
|
||||
tic = time.perf_counter()
|
||||
states = asyncio.run(run_async_evaluation())
|
||||
@@ -191,36 +196,28 @@ def evaluate_gsm8k(num_questions: int = 1319,
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GSM8K evaluation for vLLM serve")
|
||||
parser.add_argument("--num-shots",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of few-shot examples")
|
||||
parser.add_argument("--num-questions",
|
||||
type=int,
|
||||
default=1319,
|
||||
help="Number of questions to evaluate")
|
||||
parser.add_argument("--max-tokens",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Max tokens for generation")
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="http://127.0.0.1",
|
||||
help="Host URL")
|
||||
parser = argparse.ArgumentParser(description="GSM8K evaluation for vLLM serve")
|
||||
parser.add_argument(
|
||||
"--num-shots", type=int, default=5, help="Number of few-shot examples"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-questions",
|
||||
type=int,
|
||||
default=1319,
|
||||
help="Number of questions to evaluate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens", type=int, default=256, help="Max tokens for generation"
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1", help="Host URL")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port number")
|
||||
parser.add_argument("--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Temperature for generation")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Random seed for reproducibility")
|
||||
parser.add_argument("--save-results",
|
||||
type=str,
|
||||
help="Save results to JSON file")
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.0, help="Temperature for generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=42, help="Random seed for reproducibility"
|
||||
)
|
||||
parser.add_argument("--save-results", type=str, help="Save results to JSON file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user