feat(benchmarks): Add Prefix Caching Benchmark to Serving Benchmark (#3277)

commit 45b6ef6513 (parent 1956931436)
Author: Roger Wang
Date: 2024-03-27 13:39:26 -07:00
Committed by: GitHub
6 changed files with 899 additions and 157 deletions

--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py

@@ -1,8 +1,10 @@
 import json
 import os
+import sys
 import time
-from dataclasses import dataclass
-from typing import Optional
+import traceback
+from dataclasses import dataclass, field
+from typing import List, Optional
 
 import aiohttp
 from tqdm.asyncio import tqdm
@@ -26,8 +28,11 @@ class RequestFuncOutput:
     generated_text: str = ""
     success: bool = False
     latency: float = 0
-    ttft: float = 0
+    ttft: float = 0  # Time to first token
+    itl: List[float] = field(
+        default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
+    error: str = ""
 
 
 async def async_request_tgi(
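Note: the new itl field records one inter-token latency sample per decoded token after the first, alongside the existing TTFT scalar. A benchmark driver can reduce these per-request lists to summary statistics; a minimal sketch, illustrative only and not part of this commit (the summarize_itl name and the numpy dependency are assumptions):

import numpy as np

def summarize_itl(outputs):
    # outputs: iterable of RequestFuncOutput; pool ITL samples from
    # successful requests only.
    itls = [t for o in outputs if o.success for t in o.itl]
    return {
        "mean_itl_ms": 1000 * float(np.mean(itls)),
        "median_itl_ms": 1000 * float(np.median(itls)),
        "p99_itl_ms": 1000 * float(np.percentile(itls, 99)),
    }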
@@ -55,71 +60,38 @@ async def async_request_tgi(
         ttft = 0
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload) as response:
                 if response.status == 200:
-                    async for data in response.content.iter_any():
+                    async for chunk in response.content:
+                        chunk = chunk.strip()
+                        if not chunk:
+                            continue
+
+                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+
+                        data = json.loads(chunk)
+                        timestamp = time.perf_counter()
+                        # First token
                         if ttft == 0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft
-                    output.latency = time.perf_counter() - st
-
-                    body = remove_prefix(data.decode("utf-8"), "data:")
-                    output.generated_text = json.loads(body)["generated_text"]
+
+                        # Decoding phase
+                        else:
+                            output.itl.append(timestamp -
+                                              most_recent_timestamp)
+
+                        most_recent_timestamp = timestamp
+
+                    output.latency = most_recent_timestamp - st
                     output.success = True
-                else:
-                    output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
-            output.success = False
-
-        if pbar:
-            pbar.update(1)
-        return output
-
-
-async def async_request_vllm(
-    request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
-) -> RequestFuncOutput:
-    api_url = request_func_input.api_url
-    assert api_url.endswith("generate")
-
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
-        payload = {
-            "prompt": request_func_input.prompt,
-            "n": 1,
-            "best_of": request_func_input.best_of,
-            "use_beam_search": request_func_input.use_beam_search,
-            "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
-            "top_p": 1.0,
-            "max_tokens": request_func_input.output_len,
-            "ignore_eos": True,
-            "stream": True,
-        }
-        output = RequestFuncOutput()
-        output.prompt_len = request_func_input.prompt_len
-
-        ttft = 0
-        st = time.perf_counter()
-        try:
-            async with session.post(url=api_url, json=payload) as response:
-                if response.status == 200:
-                    async for data in response.content.iter_any():
-                        if ttft == 0:
-                            ttft = time.perf_counter() - st
-                            output.ttft = ttft
-                    output.latency = time.perf_counter() - st
-                    # When streaming, '\0' is appended to the end of response.
-                    body = data.decode("utf-8").strip("\0")
-                    output.generated_text = json.loads(
-                        body)["text"][0][len(request_func_input.prompt):]
-                    output.success = True
-                else:
-                    output.success = False
-
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+                    output.generated_text = data["generated_text"]
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
         if pbar:
             pbar.update(1)
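Note: the rewritten TGI client establishes the timing pattern every streaming backend in this file now shares: the first chunk fixes TTFT, each later chunk appends its gap from the previous chunk to itl, and total latency runs to the last chunk. The same pattern in isolation, as a sketch (the time_stream name and its argument are assumptions, not from the diff):

import time

async def time_stream(chunks):
    # chunks: any async iterator yielding one decoded chunk per token.
    ttft = 0.0
    itl = []
    st = time.perf_counter()
    most_recent = st
    async for _ in chunks:
        now = time.perf_counter()
        if ttft == 0.0:
            ttft = now - st  # first chunk: time to first token
        else:
            itl.append(now - most_recent)  # later chunks: inter-token gaps
        most_recent = now
    return ttft, itl, most_recent - st  # end-to-end latency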
@@ -146,26 +118,45 @@ async def async_request_trt_llm(
         }
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
-        ttft = 0
+        ttft = 0
 
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
-            async with session.post(url=api_url, json=payload) as resp:
-                if resp.status == 200:
-                    async for data in resp.content.iter_any():
+            async with session.post(url=api_url, json=payload) as response:
+                if response.status == 200:
+                    async for chunk in response.content:
+                        chunk = chunk.strip()
+                        if not chunk:
+                            continue
+
+                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
+
+                        data = json.loads(chunk)
+                        timestamp = time.perf_counter()
+                        # First token
                         if ttft == 0:
                             ttft = time.perf_counter() - st
                             output.ttft = ttft
-                    output.latency = time.perf_counter() - st
 
-                    body = remove_prefix(data.decode("utf-8"), "data:")
-                    output.generated_text = json.loads(body)["text_output"]
+                        # Decoding phase
+                        else:
+                            output.itl.append(timestamp -
+                                              most_recent_timestamp)
+
+                        most_recent_timestamp = timestamp
+
+                    output.latency = most_recent_timestamp - st
+                    output.generated_text = data["text_output"]
                     output.success = True
                 else:
+                    output.error = response.reason
                     output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
         if pbar:
             pbar.update(1)
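Note: both streaming clients strip the SSE "data:" prefix with the module's own remove_prefix helper, whose signature appears in the final hunk header below. It behaves like str.removeprefix from Python 3.9; a sketch consistent with how it is called in this file:

def remove_prefix(text: str, prefix: str) -> str:
    # Return text without prefix if present; otherwise return text unchanged.
    if text.startswith(prefix):
        return text[len(prefix):]
    return text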
@@ -181,35 +172,35 @@ async def async_request_deepspeed_mii(
         assert not request_func_input.use_beam_search
         payload = {
-            "prompts": request_func_input.prompt,
-            "max_new_tokens": request_func_input.output_len,
-            "ignore_eos": True,
-            "do_sample": True,
-            "temperature":
-            0.01,  # deepspeed-mii does not accept 0.0 temperature.
+            "prompt": request_func_input.prompt,
+            "max_tokens": request_func_input.output_len,
+            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
             "top_p": 1.0,
         }
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
 
-        # DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
+        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
         # will use 0 as placeholder.
-        # https://github.com/microsoft/DeepSpeed-MII/pull/311
+        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
         output.ttft = 0
 
         st = time.perf_counter()
         try:
             async with session.post(url=request_func_input.api_url,
-                                    json=payload) as resp:
-                if resp.status == 200:
-                    parsed_resp = await resp.json()
+                                    json=payload) as response:
+                if response.status == 200:
+                    parsed_resp = await response.json()
                     output.latency = time.perf_counter() - st
-                    output.generated_text = parsed_resp[0]["generated_text"]
+                    output.generated_text = parsed_resp["text"][0]
                     output.success = True
                 else:
+                    output.error = response.reason
                     output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
         if pbar:
             pbar.update(1)
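Note: every request function now fills the new error field with a full traceback on failure, via the same sys.exc_info() plus traceback.format_exception idiom. Shown standalone (the format_current_exception wrapper is an assumption for illustration):

import sys
import traceback

def format_current_exception() -> str:
    # Must be called from inside an except block.
    exc_info = sys.exc_info()  # (type, value, traceback) of active exception
    return "".join(traceback.format_exception(*exc_info))

try:
    raise ConnectionError("server disconnected")
except Exception:
    error_text = format_current_exception()  # multi-line traceback string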
@@ -221,7 +212,9 @@ async def async_request_openai_completions(
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
-    assert api_url.endswith("v1/completions")
+    assert api_url.endswith(
+        "v1/completions"
+    ), "OpenAI Completions API URL must end with 'v1/completions'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
@@ -243,15 +236,12 @@ async def async_request_openai_completions(
         generated_text = ""
         ttft = 0
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
                     async for chunk in response.content:
-                        if ttft == 0:
-                            ttft = time.perf_counter() - st
-                            output.ttft = ttft
-
                         chunk = chunk.strip()
                         if not chunk:
                             continue
@@ -260,16 +250,33 @@ async def async_request_openai_completions(
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
-                            body = json.loads(chunk)
-                            generated_text += body["choices"][0]["text"]
+                            data = json.loads(chunk)
+
+                            if data["choices"][0]["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                # NOTE: Some completion API might have a last
+                                # usage summary response without a token so we
+                                # do not want to include as inter-token-latency
+                                elif data.get("usage", None) is None:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += data["choices"][0]["text"]
 
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
                 else:
                     output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
         if pbar:
             pbar.update(1)
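Note: the completions client reads a server-sent-events stream in which each non-empty line carries "data: <json chunk>" and "data: [DONE]" terminates the stream. A hedged sketch of that framing (the parse_sse_line helper is illustrative, not from the diff; str.removeprefix needs Python 3.9+):

import json

def parse_sse_line(line: bytes):
    line = line.strip()
    if not line:
        return None  # blank keep-alive line
    payload = line.decode("utf-8").removeprefix("data: ")
    if payload == "[DONE]":
        return None  # end-of-stream sentinel
    return json.loads(payload)  # one streamed completion chunk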
@@ -283,7 +290,7 @@ async def async_request_openai_chat_completions(
     api_url = request_func_input.api_url
     assert api_url.endswith(
         "v1/chat/completions"
-    ), "OpenAI Chat API URL must end with 'v1/chat/completions'."
+    ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."
 
     async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         assert not request_func_input.use_beam_search
@@ -301,7 +308,7 @@ async def async_request_openai_chat_completions(
         }
         headers = {
             "Content-Type": "application/json",
-            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
         }
 
         output = RequestFuncOutput()
@@ -310,15 +317,12 @@ async def async_request_openai_chat_completions(
         generated_text = ""
         ttft = 0
         st = time.perf_counter()
+        most_recent_timestamp = st
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
                 if response.status == 200:
                     async for chunk in response.content:
-                        if ttft == 0:
-                            ttft = time.perf_counter() - st
-                            output.ttft = ttft
-
                         chunk = chunk.strip()
                         if not chunk:
                             continue
@@ -327,18 +331,35 @@ async def async_request_openai_chat_completions(
                         if chunk == "[DONE]":
                             latency = time.perf_counter() - st
                         else:
-                            body = json.loads(chunk)
-                            if "content" in body["choices"][0]["delta"]:
-                                generated_text += body["choices"][0]["delta"][
+                            timestamp = time.perf_counter()
+                            data = json.loads(chunk)
+
+                            if "content" in data["choices"][0]["delta"]:
+                                # First token
+                                if ttft == 0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp -
+                                                      most_recent_timestamp)
+
+                                generated_text += data["choices"][0]["delta"][
                                     "content"]
 
+                            most_recent_timestamp = timestamp
+
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
                 else:
+                    output.error = response.reason
                     output.success = False
-        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
+        except Exception:
             output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
 
         if pbar:
             pbar.update(1)
@@ -355,7 +376,8 @@ def remove_prefix(text: str, prefix: str) -> str:
 ASYNC_REQUEST_FUNCS = {
     "tgi": async_request_tgi,
-    "vllm": async_request_vllm,
+    "vllm": async_request_openai_completions,
+    "lmdeploy": async_request_openai_completions,
     "deepspeed-mii": async_request_deepspeed_mii,
     "openai": async_request_openai_completions,
     "openai-chat": async_request_openai_chat_completions,