[Tests] Add GSM8k check to SpecDec E2E tests (#34772)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
709eadbb0b
commit
ee59a7c615
@@ -110,6 +110,65 @@ async def call_vllm_api(
|
||||
return "", 0
|
||||
|
||||
|
||||
def _build_gsm8k_prompts(
|
||||
num_questions: int = 1319,
|
||||
num_shots: int = 5,
|
||||
) -> tuple[list[str], list[int]]:
|
||||
"""Build few-shot GSM8K completion prompts and ground-truth labels."""
|
||||
if num_questions == 0:
|
||||
return [], []
|
||||
train_data, test_data = load_gsm8k_data()
|
||||
num_questions = min(num_questions, len(test_data))
|
||||
|
||||
few_shot_examples = ""
|
||||
for i in range(num_shots):
|
||||
few_shot_examples += (
|
||||
f"Question: {train_data[i]['question']}\n"
|
||||
f"Answer: {train_data[i]['answer']}\n\n"
|
||||
)
|
||||
|
||||
prompts = []
|
||||
labels = []
|
||||
for i in range(num_questions):
|
||||
prompts.append(
|
||||
few_shot_examples + f"Question: {test_data[i]['question']}\nAnswer:"
|
||||
)
|
||||
labels.append(get_answer_value(test_data[i]["answer"]))
|
||||
|
||||
assert all(label != INVALID for label in labels), "Some labels are invalid"
|
||||
return prompts, labels
|
||||
|
||||
|
||||
def _score_gsm8k(
|
||||
states: list[str],
|
||||
output_tokens: list[int],
|
||||
labels: list[int],
|
||||
num_shots: int,
|
||||
max_tokens: int,
|
||||
latency: float,
|
||||
) -> dict[str, float | int]:
|
||||
"""Score GSM8K responses and return a results dict."""
|
||||
num_questions = len(labels)
|
||||
preds = [get_answer_value(state) for state in states]
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid_rate = np.mean(np.array(preds) == INVALID)
|
||||
total_output_tokens = sum(output_tokens)
|
||||
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
|
||||
|
||||
return {
|
||||
"accuracy": accuracy,
|
||||
"invalid_rate": invalid_rate,
|
||||
"latency": latency,
|
||||
"questions_per_second": num_questions / latency if latency > 0 else 0.0,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"tokens_per_second": tokens_per_second,
|
||||
"num_questions": num_questions,
|
||||
"num_shots": num_shots,
|
||||
"max_tokens": max_tokens,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
|
||||
def evaluate_gsm8k(
|
||||
num_questions: int = 1319,
|
||||
num_shots: int = 5,
|
||||
@@ -125,40 +184,17 @@ def evaluate_gsm8k(
|
||||
Returns dict with accuracy, invalid_rate, latency, etc.
|
||||
"""
|
||||
base_url = f"{host}:{port}"
|
||||
prompts, labels = _build_gsm8k_prompts(num_questions, num_shots)
|
||||
num_questions = len(prompts)
|
||||
|
||||
# Load GSM8K train and test data
|
||||
train_data, test_data = load_gsm8k_data()
|
||||
|
||||
# Limit to available test questions
|
||||
num_questions = min(num_questions, len(test_data))
|
||||
|
||||
# Build few-shot examples from train split (like lm-eval does)
|
||||
few_shot_examples = ""
|
||||
for i in range(num_shots):
|
||||
few_shot_examples += (
|
||||
f"Question: {train_data[i]['question']}\n"
|
||||
f"Answer: {train_data[i]['answer']}\n\n"
|
||||
)
|
||||
|
||||
# Prepare test questions and labels from test split
|
||||
questions = []
|
||||
labels = []
|
||||
for i in range(num_questions):
|
||||
questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
|
||||
labels.append(get_answer_value(test_data[i]["answer"]))
|
||||
|
||||
assert all(label != INVALID for label in labels), "Some labels are invalid"
|
||||
|
||||
# Run evaluation
|
||||
async def run_async_evaluation():
|
||||
states: list[str] = [""] * num_questions
|
||||
output_tokens: list[int] = [0] * num_questions
|
||||
|
||||
async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
|
||||
prompt = few_shot_examples + questions[i]
|
||||
answer, tokens = await call_vllm_api(
|
||||
session=session,
|
||||
prompt=prompt,
|
||||
prompt=prompts[i],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stop=["Question", "Assistant:", "<|separator|>"],
|
||||
@@ -183,27 +219,43 @@ def evaluate_gsm8k(
|
||||
states, output_tokens = asyncio.run(run_async_evaluation())
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute metrics
|
||||
preds = [get_answer_value(state) for state in states]
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid_rate = np.mean(np.array(preds) == INVALID)
|
||||
total_output_tokens = sum(output_tokens)
|
||||
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
|
||||
return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, latency)
|
||||
|
||||
result = {
|
||||
"accuracy": accuracy,
|
||||
"invalid_rate": invalid_rate,
|
||||
"latency": latency,
|
||||
"questions_per_second": num_questions / latency,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"tokens_per_second": tokens_per_second,
|
||||
"num_questions": num_questions,
|
||||
"num_shots": num_shots,
|
||||
"max_tokens": max_tokens,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
return result
|
||||
def evaluate_gsm8k_offline(
    llm,
    num_questions: int = 1319,
    num_shots: int = 5,
    max_tokens: int = 256,
    temperature: float = 0.0,
) -> dict[str, float | int]:
    """Run the GSM8K evaluation against an in-process ``llm`` object.

    Prompts come from ``_build_gsm8k_prompts`` and results are scored by
    ``_score_gsm8k``; generation goes through ``llm.generate()`` rather
    than an HTTP request to a server.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params)``
            (presumably a ``vllm.LLM`` — confirm against callers).
        num_questions: Maximum number of test questions to evaluate.
        num_shots: Number of few-shot examples prepended to each prompt.
        max_tokens: Per-completion generation limit.
        temperature: Sampling temperature passed to ``SamplingParams``.

    Returns:
        The results dict produced by ``_score_gsm8k``.
    """
    # Imported lazily so the module can be loaded without importing vllm.
    from vllm import SamplingParams

    prompts, labels = _build_gsm8k_prompts(num_questions, num_shots)

    stop_sequences = ["Question", "Assistant:", "<|separator|>"]
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop_sequences,
    )

    print(
        f"Running offline GSM8K evaluation: {len(prompts)} questions, {num_shots}-shot"
    )

    # Only the generate() call is timed.
    started_at = time.perf_counter()
    request_outputs = llm.generate(prompts, sampling_params)
    latency = time.perf_counter() - started_at

    # Collect the first completion of each request: its text and token count.
    states: list[str] = []
    output_tokens: list[int] = []
    for request_output in request_outputs:
        first_completion = request_output.outputs[0]
        states.append(first_completion.text)
        output_tokens.append(len(first_completion.token_ids))

    return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, latency)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
Reference in New Issue
Block a user