diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 3c0fea8e0..ad6f44404 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -38,6 +38,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput +from vllm.platforms import current_platform from vllm.sampling_params import BeamSearchParams from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.utils.async_utils import merge_async_iterators @@ -256,17 +257,21 @@ def run_hf( max_batch_size: int, trust_remote_code: bool, disable_detokenize: bool = False, + dtype: torch.dtype | None = torch.float16, + enable_torch_compile: bool = False, ) -> float: assert isinstance(tokenizer, PreTrainedTokenizerBase), ( "the hf backend only supports HF tokenizers" ) llm = AutoModelForCausalLM.from_pretrained( - model, dtype=torch.float16, trust_remote_code=trust_remote_code + model, dtype=dtype, trust_remote_code=trust_remote_code ) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token - llm = llm.cuda() + llm = llm.to(current_platform.device_type) + if enable_torch_compile: + llm = torch.compile(llm) pbar = tqdm(total=len(requests)) start = time.perf_counter() @@ -295,7 +300,7 @@ def run_hf( # Generate the sequences. input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids llm_outputs = llm.generate( - input_ids=input_ids.cuda(), + input_ids=input_ids.to(current_platform.device_type), do_sample=True, num_return_sequences=n, temperature=1.0, @@ -733,6 +738,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=None, help="Maximum batch size for HF backend.", ) + parser.add_argument( + "--hf-enable-torch-compile", + action="store_true", + default=False, + help="Enable Torch compile for HF backend.", + ) parser.add_argument( "--output-json", type=str, @@ -884,6 +895,8 @@ def main(args: argparse.Namespace): args.hf_max_batch_size, args.trust_remote_code, args.disable_detokenize, + dtype=args.dtype, + enable_torch_compile=args.hf_enable_torch_compile, ) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat(