[Core] Add engine option to return only deltas or final output (#7381)
@@ -246,8 +246,7 @@ class OpenAIServingChat(OpenAIServing):
     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
             return self.response_role
-        else:
-            return request.messages[-1]["role"]
+        return request.messages[-1]["role"]
 
     async def chat_completion_stream_generator(
         self,
@@ -264,15 +263,37 @@ class OpenAIServingChat(OpenAIServing):
 
         # Send response for each token for each request.n (index)
         num_choices = 1 if request.n is None else request.n
-        previous_texts = [""] * num_choices
         previous_num_tokens = [0] * num_choices
         finish_reason_sent = [False] * num_choices
 
+        num_prompt_tokens = 0
+
         tool_parser: Optional[ToolParser] = self.tool_parser(
             tokenizer) if self.tool_parser else None
 
+        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
+            tool_choice_function_name = request.tool_choice.function.name
+        else:
+            tool_choice_function_name = None
+
+        # Determine whether tools are in use with "auto" tool choice
+        tool_choice_auto = (
+            not tool_choice_function_name
+            and self._should_stream_with_auto_tool_parsing(request))
+
+        all_previous_token_ids: Optional[List[List[int]]]
+        if tool_choice_auto:
+            # These are only required in "auto" tool choice case
+            previous_texts = [""] * num_choices
+            all_previous_token_ids = [[]] * num_choices
+        else:
+            previous_texts, all_previous_token_ids = None, None
+
         try:
             async for res in result_generator:
+                if res.prompt_token_ids is not None:
+                    num_prompt_tokens = len(res.prompt_token_ids)
+
                 # We need to do it here, because if there are exceptions in
                 # the result_generator, it needs to be sent as the FIRST
                 # response (by the try...catch).
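Note: a minimal sketch of the prompt-token caching pattern introduced above. With the new delta/final-only output option, a streamed result is no longer guaranteed to carry prompt_token_ids on every chunk, so the length is captured whenever it appears and reused later. The Chunk class here is a hypothetical stand-in for the engine's streamed RequestOutput, not vLLM's actual type:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Chunk:
        # hypothetical stand-in for one streamed output
        prompt_token_ids: Optional[List[int]]

    def cached_prompt_len(chunks: List[Chunk]) -> int:
        num_prompt_tokens = 0
        for chunk in chunks:
            # capture the prompt length on whichever chunks carry it;
            # later chunks may carry None
            if chunk.prompt_token_ids is not None:
                num_prompt_tokens = len(chunk.prompt_token_ids)
        return num_prompt_tokens

    assert cached_prompt_len([Chunk([1, 2, 3]), Chunk(None)]) == 3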
@@ -305,10 +326,10 @@ class OpenAIServingChat(OpenAIServing):
                             and request.stream_options.include_usage):
                         # if continuous usage stats are requested, add it
                         if request.stream_options.continuous_usage_stats:
-                            prompt_tokens = len(res.prompt_token_ids)
-                            usage = UsageInfo(prompt_tokens=prompt_tokens,
-                                              completion_tokens=0,
-                                              total_tokens=prompt_tokens)
+                            usage = UsageInfo(
+                                prompt_tokens=num_prompt_tokens,
+                                completion_tokens=0,
+                                total_tokens=num_prompt_tokens)
                             chunk.usage = usage
                         # otherwise don't
                         else:
@@ -344,12 +365,10 @@ class OpenAIServingChat(OpenAIServing):
                                     request.stream_options.include_usage):
                                 if (request.stream_options.
                                         continuous_usage_stats):
-                                    prompt_tokens = len(
-                                        res.prompt_token_ids)
                                     usage = UsageInfo(
-                                        prompt_tokens=prompt_tokens,
+                                        prompt_tokens=num_prompt_tokens,
                                         completion_tokens=0,
-                                        total_tokens=prompt_tokens)
+                                        total_tokens=num_prompt_tokens)
                                     chunk.usage = usage
                                 else:
                                     chunk.usage = None
@@ -360,65 +379,66 @@ class OpenAIServingChat(OpenAIServing):
                     first_iteration = False
 
                 for output in res.outputs:
                     i = output.index
 
                     if finish_reason_sent[i]:
                         continue
 
-                    delta_token_ids = output.token_ids[previous_num_tokens[i]:]
-                    out_logprobs = output.logprobs[
-                        previous_num_tokens[i]:] if output.logprobs else None
-
                     if request.logprobs and request.top_logprobs is not None:
-                        assert out_logprobs is not None, (
+                        assert output.logprobs is not None, (
                             "Did not output logprobs")
                         logprobs = self._create_chat_logprobs(
-                            token_ids=delta_token_ids,
-                            top_logprobs=out_logprobs,
+                            token_ids=output.token_ids,
+                            top_logprobs=output.logprobs,
                             tokenizer=tokenizer,
                             num_output_top_logprobs=request.top_logprobs,
                         )
                     else:
                         logprobs = None
 
-                    delta_text = output.text[len(previous_texts[i]):]
-                    delta_message: Optional[DeltaMessage] = None
+                    delta_text = output.text
+                    delta_message: Optional[DeltaMessage]
 
                     # handle streaming deltas for tools with named tool_choice
-                    if (request.tool_choice and type(request.tool_choice) is
-                            ChatCompletionNamedToolChoiceParam):
+                    if tool_choice_function_name:
                         delta_message = DeltaMessage(tool_calls=[
                             DeltaToolCall(function=DeltaFunctionCall(
-                                name=request.tool_choice.function.name,
+                                name=tool_choice_function_name,
                                 arguments=delta_text),
                                           index=i)
                         ])
 
                     # handle streaming deltas for tools with "auto" tool choice
-                    elif (self._should_stream_with_auto_tool_parsing(request)
-                          and tool_parser):
+                    elif tool_choice_auto:
+                        assert previous_texts is not None
+                        assert all_previous_token_ids is not None
+                        assert tool_parser is not None
+                        #TODO optimize manipulation of these lists
+                        previous_text = previous_texts[i]
+                        previous_token_ids = all_previous_token_ids[i]
+                        current_text = previous_text + delta_text
+                        current_token_ids = previous_token_ids + list(
+                            output.token_ids)
+
                         delta_message = (
                             tool_parser.extract_tool_calls_streaming(
-                                previous_text=previous_texts[i],
-                                current_text=output.text,
+                                previous_text=previous_text,
+                                current_text=current_text,
                                 delta_text=delta_text,
-                                previous_token_ids= \
-                                    output.token_ids[
-                                        :-1 * len(delta_token_ids)
-                                    ],
-                                current_token_ids=output.token_ids,
-                                delta_token_ids=delta_token_ids
-                            )
-                        )
+                                previous_token_ids=previous_token_ids,
+                                current_token_ids=current_token_ids,
+                                delta_token_ids=output.token_ids))
+
+                        # update the previous values for the next iteration
+                        previous_texts[i] = current_text
+                        all_previous_token_ids[i] = current_token_ids
 
                     # handle streaming just a content delta
                     else:
                         delta_message = DeltaMessage(content=delta_text)
 
                     # set the previous values for the next iteration
-                    previous_texts[i] = output.text
-                    previous_num_tokens[i] = len(output.token_ids)
+                    previous_num_tokens[i] += len(output.token_ids)
 
                     # if the message delta is None (e.g. because it was a
                     # "control token" for tool calls or the parser otherwise
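Note: the core behavioral change in the hunk above is that output.text and output.token_ids now arrive as deltas instead of cumulative snapshots, so the server stops slicing off already-sent text and instead accumulates only when a tool parser needs the running total. A rough self-contained sketch of the two conventions (function names are illustrative, not vLLM API):

    def deltas_from_cumulative(snapshots):
        # old convention: each output carries the full text so far;
        # the server slices off what it already sent
        previous, deltas = "", []
        for text in snapshots:
            deltas.append(text[len(previous):])
            previous = text
        return deltas

    def cumulative_from_deltas(pieces):
        # new convention: each output is already a delta; the server
        # accumulates only when a parser needs the full text
        current, snapshots = "", []
        for piece in pieces:
            current += piece
            snapshots.append(current)
        return snapshots

    assert deltas_from_cumulative(["He", "Hello", "Hello!"]) == ["He", "llo", "!"]
    assert cumulative_from_deltas(["He", "llo", "!"]) == ["He", "Hello", "Hello!"]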
@@ -445,13 +465,12 @@ class OpenAIServingChat(OpenAIServing):
                     # handle usage stats if requested & if continuous
                     if (request.stream_options
                             and request.stream_options.include_usage):
-                        if (request.stream_options.continuous_usage_stats):
-                            prompt_tokens = len(res.prompt_token_ids)
+                        if request.stream_options.continuous_usage_stats:
                             completion_tokens = len(output.token_ids)
                             usage = UsageInfo(
-                                prompt_tokens=prompt_tokens,
+                                prompt_tokens=num_prompt_tokens,
                                 completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens +
+                                total_tokens=num_prompt_tokens +
                                 completion_tokens,
                             )
                             chunk.usage = usage
@@ -482,7 +501,7 @@ class OpenAIServingChat(OpenAIServing):
                             tool_parser.prev_tool_call_arr[index].get(
                                 "arguments", {}))
 
-                        # get what we've streamed so for for arguments
+                        # get what we've streamed so far for arguments
                         # for the current tool
                         actual_call = tool_parser.streamed_args_for_tool[
                             index]
@@ -500,7 +519,6 @@ class OpenAIServingChat(OpenAIServing):
                         ])
 
                     # Send the finish response for each request.n only once
-                    prompt_tokens = len(res.prompt_token_ids)
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=delta_message,
@@ -518,13 +536,12 @@ class OpenAIServingChat(OpenAIServing):
                                                          model=model_name)
                     if (request.stream_options
                             and request.stream_options.include_usage):
-                        if (request.stream_options.continuous_usage_stats):
-                            prompt_tokens = len(res.prompt_token_ids)
+                        if request.stream_options.continuous_usage_stats:
                             completion_tokens = len(output.token_ids)
                             usage = UsageInfo(
-                                prompt_tokens=prompt_tokens,
+                                prompt_tokens=num_prompt_tokens,
                                 completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens +
+                                total_tokens=num_prompt_tokens +
                                 completion_tokens,
                             )
                             chunk.usage = usage
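Note: the usage-stat hunks above all make the same substitution: instead of recomputing len(res.prompt_token_ids) at each site (a value that can now be absent mid-stream), they reuse the cached num_prompt_tokens. A simplified model of the arithmetic, with UsageInfo reduced to a plain dataclass stand-in for vLLM's pydantic model:

    from dataclasses import dataclass

    @dataclass
    class UsageInfo:
        prompt_tokens: int
        completion_tokens: int
        total_tokens: int

    def usage_for_chunk(num_prompt_tokens: int, completion_tokens: int) -> UsageInfo:
        # total is always prompt + completion; the prompt count comes
        # from the value cached at the start of the stream
        return UsageInfo(prompt_tokens=num_prompt_tokens,
                         completion_tokens=completion_tokens,
                         total_tokens=num_prompt_tokens + completion_tokens)

    assert usage_for_chunk(3, 2).total_tokens == 5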
@@ -538,10 +555,11 @@ class OpenAIServingChat(OpenAIServing):
             # is sent, send the usage
             if (request.stream_options
                     and request.stream_options.include_usage):
+                completion_tokens = previous_num_tokens[i]
                 final_usage = UsageInfo(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=previous_num_tokens[i],
-                    total_tokens=prompt_tokens + previous_num_tokens[i],
+                    prompt_tokens=num_prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=num_prompt_tokens + completion_tokens,
                 )
 
                 final_usage_chunk = ChatCompletionStreamResponse(
@@ -680,6 +698,7 @@ class OpenAIServingChat(OpenAIServing):
                            or "")
                 choice.message.content = full_message
 
+        assert final_res.prompt_token_ids is not None
         num_prompt_tokens = len(final_res.prompt_token_ids)
         num_generated_tokens = sum(
             len(output.token_ids) for output in final_res.outputs)
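Note: the added assert above is the usual Optional-narrowing idiom: in the non-streaming path the final result must carry prompt_token_ids, and the assert both documents that invariant and lets a type checker accept the len() call. A tiny sketch:

    from typing import List, Optional

    def prompt_len(prompt_token_ids: Optional[List[int]]) -> int:
        # narrows Optional[List[int]] to List[int] for the checker
        assert prompt_token_ids is not None
        return len(prompt_token_ids)

    assert prompt_len([1, 2, 3]) == 3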
@@ -789,9 +808,9 @@ class OpenAIServingChat(OpenAIServing):
         return bool(
             # if there is a delta message that includes tool calls which
             # include a function that has arguments
-            self.enable_auto_tools and self.tool_parser and delta_message
+            output.finish_reason is not None
+            and self.enable_auto_tools and self.tool_parser and delta_message
             and delta_message.tool_calls and delta_message.tool_calls[0]
             and delta_message.tool_calls[0].function
             and delta_message.tool_calls[0].function.arguments is not None
-            and output.finish_reason is not None
         )
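Note: the reordering in the last hunk moves the finish_reason test to the front of the and-chain so the common case (stream still running, finish_reason is None) short-circuits before inspecting the delta message at all; the set of conditions is unchanged. An illustrative free-function version with plain stand-in objects (the enclosing method's name and types are not shown in this diff):

    from types import SimpleNamespace

    def should_check_unstreamed_args(output, delta_message,
                                     enable_auto_tools=True,
                                     tool_parser=object()):
        return bool(
            # cheap check first: most chunks have no finish_reason
            output.finish_reason is not None
            and enable_auto_tools and tool_parser and delta_message
            and delta_message.tool_calls and delta_message.tool_calls[0]
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None)

    assert not should_check_unstreamed_args(
        SimpleNamespace(finish_reason=None), None)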