Add script for benchmarking serving throughput (#145)

This commit is contained in:
Woosuk Kwon
2023-06-14 19:55:38 -07:00
committed by GitHub
parent da5ddcd544
commit 311490a720
10 changed files with 421 additions and 415 deletions

View File

@@ -10,6 +10,7 @@ def main(args: argparse.Namespace):
prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
for i in range(args.n_threads)]
api_url = f"http://{args.host}:{args.port}/generate"
headers = {"User-Agent": "CacheFlow Benchmark Client"}
ploads = [{
"prompt": p,
@@ -19,8 +20,8 @@ def main(args: argparse.Namespace):
} for p in prompts]
def send_request(results, i):
response = requests.post(args.api_url, headers=headers,
json=ploads[i], stream=True)
response = requests.post(api_url, headers=headers, json=ploads[i],
stream=True)
results[i] = response
# use args.n_threads to prompt the backend
@@ -50,7 +51,8 @@ def main(args: argparse.Namespace):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--max-tokens", type=int, default=128)
parser.add_argument("--n-threads", type=int, default=128)
args = parser.parse_args()