[Tests] Add GSM8k check to SpecDec E2E tests (#34772)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett
2026-02-25 07:51:14 -05:00
committed by GitHub
parent 709eadbb0b
commit ee59a7c615
2 changed files with 239 additions and 145 deletions

View File

@@ -110,6 +110,65 @@ async def call_vllm_api(
return "", 0
def _build_gsm8k_prompts(
num_questions: int = 1319,
num_shots: int = 5,
) -> tuple[list[str], list[int]]:
"""Build few-shot GSM8K completion prompts and ground-truth labels."""
if num_questions == 0:
return [], []
train_data, test_data = load_gsm8k_data()
num_questions = min(num_questions, len(test_data))
few_shot_examples = ""
for i in range(num_shots):
few_shot_examples += (
f"Question: {train_data[i]['question']}\n"
f"Answer: {train_data[i]['answer']}\n\n"
)
prompts = []
labels = []
for i in range(num_questions):
prompts.append(
few_shot_examples + f"Question: {test_data[i]['question']}\nAnswer:"
)
labels.append(get_answer_value(test_data[i]["answer"]))
assert all(label != INVALID for label in labels), "Some labels are invalid"
return prompts, labels
def _score_gsm8k(
states: list[str],
output_tokens: list[int],
labels: list[int],
num_shots: int,
max_tokens: int,
latency: float,
) -> dict[str, float | int]:
"""Score GSM8K responses and return a results dict."""
num_questions = len(labels)
preds = [get_answer_value(state) for state in states]
accuracy = np.mean(np.array(preds) == np.array(labels))
invalid_rate = np.mean(np.array(preds) == INVALID)
total_output_tokens = sum(output_tokens)
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
return {
"accuracy": accuracy,
"invalid_rate": invalid_rate,
"latency": latency,
"questions_per_second": num_questions / latency if latency > 0 else 0.0,
"total_output_tokens": total_output_tokens,
"tokens_per_second": tokens_per_second,
"num_questions": num_questions,
"num_shots": num_shots,
"max_tokens": max_tokens,
"timestamp": time.time(),
}
def evaluate_gsm8k(
num_questions: int = 1319,
num_shots: int = 5,
@@ -125,40 +184,17 @@ def evaluate_gsm8k(
Returns dict with accuracy, invalid_rate, latency, etc.
"""
base_url = f"{host}:{port}"
prompts, labels = _build_gsm8k_prompts(num_questions, num_shots)
num_questions = len(prompts)
# Load GSM8K train and test data
train_data, test_data = load_gsm8k_data()
# Limit to available test questions
num_questions = min(num_questions, len(test_data))
# Build few-shot examples from train split (like lm-eval does)
few_shot_examples = ""
for i in range(num_shots):
few_shot_examples += (
f"Question: {train_data[i]['question']}\n"
f"Answer: {train_data[i]['answer']}\n\n"
)
# Prepare test questions and labels from test split
questions = []
labels = []
for i in range(num_questions):
questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
labels.append(get_answer_value(test_data[i]["answer"]))
assert all(label != INVALID for label in labels), "Some labels are invalid"
# Run evaluation
async def run_async_evaluation():
states: list[str] = [""] * num_questions
output_tokens: list[int] = [0] * num_questions
async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
prompt = few_shot_examples + questions[i]
answer, tokens = await call_vllm_api(
session=session,
prompt=prompt,
prompt=prompts[i],
temperature=temperature,
max_tokens=max_tokens,
stop=["Question", "Assistant:", "<|separator|>"],
@@ -183,27 +219,43 @@ def evaluate_gsm8k(
states, output_tokens = asyncio.run(run_async_evaluation())
latency = time.perf_counter() - tic
# Compute metrics
preds = [get_answer_value(state) for state in states]
accuracy = np.mean(np.array(preds) == np.array(labels))
invalid_rate = np.mean(np.array(preds) == INVALID)
total_output_tokens = sum(output_tokens)
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, latency)
result = {
"accuracy": accuracy,
"invalid_rate": invalid_rate,
"latency": latency,
"questions_per_second": num_questions / latency,
"total_output_tokens": total_output_tokens,
"tokens_per_second": tokens_per_second,
"num_questions": num_questions,
"num_shots": num_shots,
"max_tokens": max_tokens,
"timestamp": time.time(),
}
return result
def evaluate_gsm8k_offline(
llm,
num_questions: int = 1319,
num_shots: int = 5,
max_tokens: int = 256,
temperature: float = 0.0,
) -> dict[str, float | int]:
"""Evaluate GSM8K accuracy using an offline vllm.LLM object.
Same prompts and scoring as evaluate_gsm8k(), but runs generation
directly via llm.generate() instead of calling a server over HTTP.
"""
from vllm import SamplingParams
prompts, labels = _build_gsm8k_prompts(num_questions, num_shots)
sampling_params = SamplingParams(
temperature=temperature,
max_tokens=max_tokens,
stop=["Question", "Assistant:", "<|separator|>"],
)
print(
f"Running offline GSM8K evaluation: {len(prompts)} questions, {num_shots}-shot"
)
tic = time.perf_counter()
outputs = llm.generate(prompts, sampling_params)
latency = time.perf_counter() - tic
states = [o.outputs[0].text for o in outputs]
output_tokens = [len(o.outputs[0].token_ids) for o in outputs]
return _score_gsm8k(states, output_tokens, labels, num_shots, max_tokens, latency)
def main() -> None: