Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -12,12 +12,11 @@ FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
|
||||
# Model-specific expected values
|
||||
EXPECTED_VALUES = {
|
||||
"Qwen/Qwen3-0.6B": 0.41,
|
||||
"deepseek-ai/deepseek-vl2-small": 0.59
|
||||
}
|
||||
EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59}
|
||||
|
||||
SIMPLE_PROMPT = "The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means", # noqa: E501
|
||||
SIMPLE_PROMPT = (
|
||||
"The best part about working on vLLM is that I got to meet so many people across various different organizations like UCB, Google, and Meta which means",
|
||||
) # noqa: E501
|
||||
|
||||
# Get model name from environment variable
|
||||
MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")
|
||||
@@ -25,8 +24,7 @@ MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")
|
||||
|
||||
def run_simple_prompt():
|
||||
client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
|
||||
completion = client.completions.create(model=MODEL_NAME,
|
||||
prompt=SIMPLE_PROMPT)
|
||||
completion = client.completions.create(model=MODEL_NAME, prompt=SIMPLE_PROMPT)
|
||||
|
||||
print("-" * 50)
|
||||
print(f"Completion results for {MODEL_NAME}:")
|
||||
@@ -38,9 +36,11 @@ def test_accuracy():
|
||||
"""Run the end to end accuracy test."""
|
||||
run_simple_prompt()
|
||||
|
||||
model_args = (f"model={MODEL_NAME},"
|
||||
f"base_url={BASE_URL}/completions,"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
||||
model_args = (
|
||||
f"model={MODEL_NAME},"
|
||||
f"base_url={BASE_URL}/completions,"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
|
||||
)
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
@@ -52,11 +52,14 @@ def test_accuracy():
|
||||
expected_value = EXPECTED_VALUES.get(MODEL_NAME)
|
||||
|
||||
if expected_value is None:
|
||||
print(f"Warning: No expected value found for {MODEL_NAME}. "
|
||||
"Skipping accuracy check.")
|
||||
print(
|
||||
f"Warning: No expected value found for {MODEL_NAME}. "
|
||||
"Skipping accuracy check."
|
||||
)
|
||||
print(f"Measured value: {measured_value}")
|
||||
return
|
||||
|
||||
assert (measured_value - RTOL < expected_value
|
||||
and measured_value + RTOL > expected_value
|
||||
), f"Expected: {expected_value} | Measured: {measured_value}"
|
||||
assert (
|
||||
measured_value - RTOL < expected_value
|
||||
and measured_value + RTOL > expected_value
|
||||
), f"Expected: {expected_value} | Measured: {measured_value}"
|
||||
|
||||
Reference in New Issue
Block a user