Stop bench CLI from recursively casting all configs to dict (#37559)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2026-03-19 14:04:03 +00:00
committed by GitHub
parent 9515c20868
commit 572b432913
8 changed files with 17 additions and 16 deletions
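
The change is identical across the files below: dataclasses.asdict() converts a dataclass to a dict recursively, so LLM(**dataclasses.asdict(engine_args)) also flattened every nested config object held by EngineArgs into a plain dict. Building the kwargs with a shallow comprehension over dataclasses.fields() passes each top-level attribute through untouched. A minimal sketch of the difference, using toy stand-in classes rather than vLLM's real EngineArgs:

from dataclasses import asdict, dataclass, field, fields

@dataclass
class CompilationConfig:  # toy stand-in for a nested vLLM config object
    level: int = 0

@dataclass
class EngineArgs:  # toy stand-in for vllm.engine.arg_utils.EngineArgs
    model: str = "facebook/opt-125m"
    compilation_config: CompilationConfig = field(default_factory=CompilationConfig)

engine_args = EngineArgs()

# Old behaviour: asdict() recurses, so the nested config arrives as a plain dict.
deep = asdict(engine_args)
print(type(deep["compilation_config"]))     # <class 'dict'>

# New behaviour: a shallow copy of the top-level fields keeps the object intact.
shallow = {f.name: getattr(engine_args, f.name) for f in fields(engine_args)}
print(type(shallow["compilation_config"]))  # <class '__main__.CompilationConfig'>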

View File

@@ -40,9 +40,9 @@ LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more
details.
"""
-import dataclasses
import random
import time
+from dataclasses import fields
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
@@ -124,7 +124,7 @@ def main(args):
# Create the LLM engine
engine_args = EngineArgs.from_cli_args(args)
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
print("------warm up------")

View File

@@ -32,6 +32,7 @@ import dataclasses
import json
import random
import time
+from dataclasses import fields
from transformers import PreTrainedTokenizerBase
@@ -196,7 +197,7 @@ def main(args):
engine_args = EngineArgs.from_cli_args(args)
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
sampling_params = SamplingParams(
temperature=0,

View File

@@ -3,10 +3,10 @@
"""Benchmark offline prioritization."""
import argparse
-import dataclasses
import json
import random
import time
+from dataclasses import fields
from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -79,7 +79,7 @@ def run_vllm(
) -> float:
from vllm import LLM, SamplingParams
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all(
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])

View File

@@ -3,10 +3,10 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
-import dataclasses
import json
import os
import time
+from dataclasses import fields
from typing import Any
import numpy as np
@@ -85,7 +85,7 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert llm.llm_engine.model_config.max_model_len >= (
args.input_len + args.output_len
), (

View File

@@ -14,10 +14,10 @@ Run:
"""
import argparse
-import dataclasses
import json
import time
from collections import defaultdict
+from dataclasses import fields
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal
@@ -225,7 +225,7 @@ def benchmark_multimodal_processor(
args.seed = 0
engine_args = EngineArgs.from_cli_args(args)
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
tokenizer = llm.get_tokenizer()
requests = get_requests(args, tokenizer)

View File

@@ -9,7 +9,6 @@ and cache operations) for both cold and warm scenarios:
"""
import argparse
-import dataclasses
import json
import multiprocessing
import os
@@ -17,6 +16,7 @@ import shutil
import tempfile
import time
from contextlib import contextmanager
+from dataclasses import fields
from typing import Any
import numpy as np
@@ -67,7 +67,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
# Measure total startup time
start_time = time.perf_counter()
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
total_startup_time = time.perf_counter() - start_time

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass, fields
from pathlib import Path
from typing import ClassVar, Literal, get_args
@@ -267,7 +267,7 @@ class SweepServeWorkloadArgs(SweepServeArgs):
base_args = SweepServeArgs.from_cli_args(args)
return cls(
-**asdict(base_args),
+**{f.name: getattr(base_args, f.name) for f in fields(base_args)},
workload_var=args.workload_var,
workload_iters=args.workload_iters,
)
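
The sweep script applies the same shallow pattern when constructing a subclass from its parent: every field of SweepServeArgs is forwarded as-is instead of going through asdict(), so field values that are themselves dataclasses or other rich objects reach the subclass constructor unconverted. A hedged sketch of that construction pattern, with illustrative class and field names (the real SweepServeArgs carries many more fields):

from dataclasses import dataclass, fields

@dataclass
class SweepServeArgsSketch:  # illustrative stand-in for SweepServeArgs
    serve_cmd: str = "vllm serve"  # hypothetical field, for demonstration only

@dataclass
class SweepServeWorkloadArgsSketch(SweepServeArgsSketch):
    workload_var: str = ""
    workload_iters: int = 1

    @classmethod
    def from_base(cls, base: SweepServeArgsSketch, var: str, iters: int):
        # Shallow field forwarding: no recursive conversion of nested values.
        return cls(
            **{f.name: getattr(base, f.name) for f in fields(base)},
            workload_var=var,
            workload_iters=iters,
        )

args = SweepServeWorkloadArgsSketch.from_base(SweepServeArgsSketch(), "qps", 3)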

View File

@@ -3,12 +3,12 @@
"""Benchmark offline inference throughput."""
import argparse
-import dataclasses
import json
import os
import random
import time
import warnings
+from dataclasses import fields
from typing import Any
import torch
@@ -53,7 +53,7 @@ def run_vllm(
) -> tuple[float, list[RequestOutput] | None]:
from vllm import LLM, SamplingParams
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all(
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
@@ -141,7 +141,7 @@ def run_vllm_chat(
"""
from vllm import LLM, SamplingParams
-llm = LLM(**dataclasses.asdict(engine_args))
+llm = LLM(**{f.name: getattr(engine_args, f.name) for f in fields(engine_args)})
assert all(
llm.llm_engine.model_config.max_model_len