refine vllm bench throughput --backend hf (#35971)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -38,6 +38,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
@@ -256,17 +257,21 @@ def run_hf(
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
disable_detokenize: bool = False,
|
||||
dtype: torch.dtype | None = torch.float16,
|
||||
enable_torch_compile: bool = False,
|
||||
) -> float:
|
||||
assert isinstance(tokenizer, PreTrainedTokenizerBase), (
|
||||
"the hf backend only supports HF tokenizers"
|
||||
)
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, dtype=torch.float16, trust_remote_code=trust_remote_code
|
||||
model, dtype=dtype, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
llm = llm.to(current_platform.device_type)
|
||||
if enable_torch_compile:
|
||||
llm = torch.compile(llm)
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.perf_counter()
|
||||
@@ -295,7 +300,7 @@ def run_hf(
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
input_ids=input_ids.to(current_platform.device_type),
|
||||
do_sample=True,
|
||||
num_return_sequences=n,
|
||||
temperature=1.0,
|
||||
@@ -733,6 +738,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-enable-torch-compile",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable Torch compile for HF backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
@@ -884,6 +895,8 @@ def main(args: argparse.Namespace):
|
||||
args.hf_max_batch_size,
|
||||
args.trust_remote_code,
|
||||
args.disable_detokenize,
|
||||
dtype=args.dtype,
|
||||
enable_torch_compile=args.hf_enable_torch_compile,
|
||||
)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
|
||||
Reference in New Issue
Block a user