Refine `vllm bench throughput --backend hf`: configurable dtype, platform-agnostic device placement, and optional torch.compile (#35971)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2026-03-07 10:10:33 +08:00
committed by GitHub
parent c7f32e08c2
commit 7eb524e64c

View File

@@ -38,6 +38,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import BeamSearchParams
from vllm.tokenizers import TokenizerLike, get_tokenizer
from vllm.utils.async_utils import merge_async_iterators
@@ -256,17 +257,21 @@ def run_hf(
max_batch_size: int,
trust_remote_code: bool,
disable_detokenize: bool = False,
dtype: torch.dtype | None = torch.float16,
enable_torch_compile: bool = False,
) -> float:
assert isinstance(tokenizer, PreTrainedTokenizerBase), (
"the hf backend only supports HF tokenizers"
)
llm = AutoModelForCausalLM.from_pretrained(
model, dtype=torch.float16, trust_remote_code=trust_remote_code
model, dtype=dtype, trust_remote_code=trust_remote_code
)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
llm = llm.cuda()
llm = llm.to(current_platform.device_type)
if enable_torch_compile:
llm = torch.compile(llm)
pbar = tqdm(total=len(requests))
start = time.perf_counter()
@@ -295,7 +300,7 @@ def run_hf(
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
input_ids=input_ids.to(current_platform.device_type),
do_sample=True,
num_return_sequences=n,
temperature=1.0,
@@ -733,6 +738,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=None,
help="Maximum batch size for HF backend.",
)
parser.add_argument(
"--hf-enable-torch-compile",
action="store_true",
default=False,
help="Enable Torch compile for HF backend.",
)
parser.add_argument(
"--output-json",
type=str,
@@ -884,6 +895,8 @@ def main(args: argparse.Namespace):
args.hf_max_batch_size,
args.trust_remote_code,
args.disable_detokenize,
dtype=args.dtype,
enable_torch_compile=args.hf_enable_torch_compile,
)
elif args.backend == "vllm-chat":
elapsed_time, request_outputs = run_vllm_chat(