[Core] Add engine option to return only deltas or final output (#7381)

This commit is contained in:
Nick Hill
2024-09-12 20:02:00 +01:00
committed by GitHub
parent a6c0f3658d
commit 551ce01078
10 changed files with 371 additions and 137 deletions

View File

@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1]["role"]
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self,
@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
# Send response for each token for each request.n (index)
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
tool_choice_function_name = None
# Determine whether tools are in use with "auto" tool choice
tool_choice_auto = (
not tool_choice_function_name
and self._should_stream_with_auto_tool_parsing(request))
all_previous_token_ids: Optional[List[List[int]]]
if tool_choice_auto:
# These are only required in "auto" tool choice case
previous_texts = [""] * num_choices
all_previous_token_ids = [[]] * num_choices
else:
previous_texts, all_previous_token_ids = None, None
try:
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
# response (by the try...except).
@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
and request.stream_options.include_usage):
# if continuous usage stats are requested, add it
if request.stream_options.continuous_usage_stats:
prompt_tokens = len(res.prompt_token_ids)
usage = UsageInfo(prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=num_prompt_tokens)
chunk.usage = usage
# otherwise don't
else:
@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
request.stream_options.include_usage):
if (request.stream_options.
continuous_usage_stats):
prompt_tokens = len(
res.prompt_token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens)
total_tokens=num_prompt_tokens)
chunk.usage = usage
else:
chunk.usage = None
@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = False
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
out_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs and request.top_logprobs is not None:
assert out_logprobs is not None, (
assert output.logprobs is not None, (
"Did not output logprobs")
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
token_ids=output.token_ids,
top_logprobs=output.logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
delta_message: Optional[DeltaMessage] = None
delta_text = output.text
delta_message: Optional[DeltaMessage]
# handle streaming deltas for tools with named tool_choice
if (request.tool_choice and type(request.tool_choice) is
ChatCompletionNamedToolChoiceParam):
if tool_choice_function_name:
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=request.tool_choice.function.name,
name=tool_choice_function_name,
arguments=delta_text),
index=i)
])
# handle streaming deltas for tools with "auto" tool choice
elif (self._should_stream_with_auto_tool_parsing(request)
and tool_parser):
elif tool_choice_auto:
assert previous_texts is not None
assert all_previous_token_ids is not None
assert tool_parser is not None
#TODO optimize manipulation of these lists
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + list(
output.token_ids)
delta_message = (
tool_parser.extract_tool_calls_streaming(
previous_text=previous_texts[i],
current_text=output.text,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids= \
output.token_ids[
:-1 * len(delta_token_ids)
],
current_token_ids=output.token_ids,
delta_token_ids=delta_token_ids
)
)
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids))
# update the previous values for the next iteration
previous_texts[i] = current_text
all_previous_token_ids[i] = current_token_ids
# handle streaming just a content delta
else:
delta_message = DeltaMessage(content=delta_text)
# set the previous values for the next iteration
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_num_tokens[i] += len(output.token_ids)
# if the message delta is None (e.g. because it was a
# "control token" for tool calls or the parser otherwise
@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
# handle usage stats if requested & if continuous
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
tool_parser.prev_tool_call_arr[index].get(
"arguments", {}))
# get what we've streamed so for for arguments
# get what we've streamed so far for arguments
# for the current tool
actual_call = tool_parser.streamed_args_for_tool[
index]
@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
])
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=delta_message,
@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
model=model_name)
if (request.stream_options
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats):
prompt_tokens = len(res.prompt_token_ids)
if request.stream_options.continuous_usage_stats:
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens +
total_tokens=num_prompt_tokens +
completion_tokens,
)
chunk.usage = usage
@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
# is sent, send the usage
if (request.stream_options
and request.stream_options.include_usage):
completion_tokens = previous_num_tokens[i]
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
prompt_tokens=num_prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=num_prompt_tokens + completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(
@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
or "")
choice.message.content = full_message
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
return bool(
# if there is a delta message that includes tool calls which
# include a function that has arguments
self.enable_auto_tools and self.tool_parser and delta_message
output.finish_reason is not None
and self.enable_auto_tools and self.tool_parser and delta_message
and delta_message.tool_calls and delta_message.tool_calls[0]
and delta_message.tool_calls[0].function
and delta_message.tool_calls[0].function.arguments is not None
and output.finish_reason is not None
)