Change the name to vLLM (#150)
@@ -5,12 +5,13 @@ import random
 import time
 from typing import List, Tuple
 
-from cacheflow import LLM, SamplingParams
 import torch
 from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
                           PreTrainedTokenizerBase)
 from tqdm import tqdm
 
+from vllm import LLM, SamplingParams
+
 
 def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
     config = AutoConfig.from_pretrained(model_name)
@@ -70,7 +71,7 @@ def sample_requests(
     return sampled_requests
 
 
-def run_cacheflow(
+def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
     tensor_parallel_size: int,
@@ -172,8 +173,8 @@ def main(args: argparse.Namespace):
     tokenizer = get_tokenizer(args.model)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
-    if args.backend == "cacheflow":
-        elapsed_time = run_cacheflow(
+    if args.backend == "vllm":
+        elapsed_time = run_vllm(
             requests, args.model, args.tensor_parallel_size, args.seed, args.n,
             args.use_beam_search)
     elif args.backend == "hf":
@@ -192,8 +193,8 @@ def main(args: argparse.Namespace):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
-    parser.add_argument("--backend", type=str, choices=["cacheflow", "hf"],
-                        default="cacheflow")
+    parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
+                        default="vllm")
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
@@ -207,7 +208,7 @@ if __name__ == "__main__":
     parser.add_argument("--hf-max-batch-size", type=int, default=None,
                         help="Maximum batch size for HF backend.")
     args = parser.parse_args()
-    if args.backend == "cacheflow":
+    if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
             raise ValueError("HF max batch size is only for HF backend.")
     elif args.backend == "hf":
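
With the rename applied, the benchmark is invoked with the new backend name. A hypothetical invocation, assuming the diffed script lives at benchmarks/benchmark_throughput.py (the file path and dataset path are placeholders, not stated in this commit):

    python benchmarks/benchmark_throughput.py \
        --backend vllm \
        --dataset /path/to/dataset.json \
        --model facebook/opt-125m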
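For reference, a minimal sketch of the renamed entry points this commit switches to; only the import path comes from the diff, while the model name, prompt, and sampling values below are illustrative:

    from vllm import LLM, SamplingParams  # was: from cacheflow import LLM, SamplingParams

    # Illustrative values; not taken from the commit.
    llm = LLM(model="facebook/opt-125m")
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # generate() returns one RequestOutput per prompt, each holding
    # the prompt and its generated completions.
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)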