Add script for benchmarking serving throughput (#145)
This commit is contained in:
@@ -10,6 +10,7 @@ def main(args: argparse.Namespace):
|
||||
prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
|
||||
for i in range(args.n_threads)]
|
||||
|
||||
api_url = f"http://{args.host}:{args.port}/generate"
|
||||
headers = {"User-Agent": "CacheFlow Benchmark Client"}
|
||||
ploads = [{
|
||||
"prompt": p,
|
||||
@@ -19,8 +20,8 @@ def main(args: argparse.Namespace):
|
||||
} for p in prompts]
|
||||
|
||||
def send_request(results, i):
|
||||
response = requests.post(args.api_url, headers=headers,
|
||||
json=ploads[i], stream=True)
|
||||
response = requests.post(api_url, headers=headers, json=ploads[i],
|
||||
stream=True)
|
||||
results[i] = response
|
||||
|
||||
# use args.n_threads to prompt the backend
|
||||
@@ -50,7 +51,8 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
parser.add_argument("--max-tokens", type=int, default=128)
|
||||
parser.add_argument("--n-threads", type=int, default=128)
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user