Change the name to vLLM (#150)

Author: Woosuk Kwon
Date: 2023-06-17 03:07:40 -07:00
Committed by: GitHub
Parent: e5464ee484
Commit: 0b98ba15c7
90 changed files with 342 additions and 339 deletions


@@ -5,12 +5,13 @@ import random
 import time
 from typing import List, Tuple
 
-from cacheflow import LLM, SamplingParams
 import torch
 from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
                           PreTrainedTokenizerBase)
 from tqdm import tqdm
 
+from vllm import LLM, SamplingParams
+
 
 def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
     config = AutoConfig.from_pretrained(model_name)
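The import hunk above is the user-facing core of the rename: the package previously imported as cacheflow is now imported as vllm, with the same LLM and SamplingParams entry points. As a quick illustration (a minimal sketch, not part of this commit, assuming the 0.1-era vllm API; the prompt and sampling values are placeholders, and the model name is simply the script's default), the renamed import is used like this:

    from vllm import LLM, SamplingParams

    # Placeholder sampling settings and prompt, for illustration only.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=32)
    llm = LLM(model="facebook/opt-125m")

    # generate() returns one RequestOutput per prompt.
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    for output in outputs:
        print(output.prompt, output.outputs[0].text)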
@@ -70,7 +71,7 @@ def sample_requests(
     return sampled_requests
 
 
-def run_cacheflow(
+def run_vllm(
     requests: List[Tuple[str, int, int]],
     model: str,
     tensor_parallel_size: int,
@@ -172,8 +173,8 @@ def main(args: argparse.Namespace):
     tokenizer = get_tokenizer(args.model)
     requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
-    if args.backend == "cacheflow":
-        elapsed_time = run_cacheflow(
+    if args.backend == "vllm":
+        elapsed_time = run_vllm(
             requests, args.model, args.tensor_parallel_size, args.seed, args.n,
             args.use_beam_search)
     elif args.backend == "hf":
@@ -192,8 +193,8 @@ def main(args: argparse.Namespace):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Benchmark the throughput.")
-    parser.add_argument("--backend", type=str, choices=["cacheflow", "hf"],
-                        default="cacheflow")
+    parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
+                        default="vllm")
     parser.add_argument("--dataset", type=str, required=True,
                         help="Path to the dataset.")
     parser.add_argument("--model", type=str, default="facebook/opt-125m")
@@ -207,7 +208,7 @@ if __name__ == "__main__":
parser.add_argument("--hf-max-batch-size", type=int, default=None,
help="Maximum batch size for HF backend.")
args = parser.parse_args()
if args.backend == "cacheflow":
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
elif args.backend == "hf":
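On the command line, the same rename shows up as the backend flag: runs that previously passed --backend cacheflow now pass --backend vllm, which is also the new default. A hypothetical invocation, assuming the file shown here is saved as benchmark_throughput.py (its path is not visible in this view) and that a dataset file exists at the given path; the flag values are illustrative:

    python benchmark_throughput.py --backend vllm --dataset /path/to/dataset.json --model facebook/opt-125m --num-prompts 100

Passing --hf-max-batch-size together with --backend vllm still raises a ValueError, as the validation block above shows.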