[Core] Add engine option to return only deltas or final output (#7381)

Author: Nick Hill
Date: 2024-09-12 20:02:00 +01:00
Committed by: GitHub
Parent: a6c0f3658d
Commit: 551ce01078

10 changed files with 371 additions and 137 deletions
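What the diff below does in OpenAIServingCompletion: when the engine streams deltas, each output already carries only the newly generated text and tokens, so the server no longer slices cumulative output against the text it has already sent. It just accumulates lengths and token counts (previous_text_lens, previous_num_tokens) and caches per-prompt token counts (num_prompt_tokens), since prompt details only appear on the first streamed output. A minimal sketch of the two bookkeeping styles, in plain Python rather than the actual vLLM code:

    # Sketch: consuming a stream of outputs in the two modes this change deals with.
    # `chunks` stands in for the per-step `output` objects seen in the diff.

    def consume_cumulative(chunks):
        # Old behaviour: each chunk carries the full text so far, so the
        # consumer must slice off what it has already emitted.
        previous_text = ""
        for chunk in chunks:
            delta = chunk[len(previous_text):]
            previous_text = chunk
            yield delta

    def consume_delta(chunks):
        # New behaviour: each chunk is already a delta; the consumer only
        # tracks how much it has emitted (previous_text_lens in the diff),
        # which is still needed e.g. for logprob text offsets.
        previous_text_len = 0
        for delta in chunks:
            previous_text_len += len(delta)
            yield delta

    # Both yield the same sequence of deltas for the same generation:
    cumulative = ["He", "Hell", "Hello", "Hello!"]
    deltas = ["He", "ll", "o", "!"]
    assert list(consume_cumulative(cumulative)) == list(consume_delta(deltas)) == deltas

The token accounting changes the same way: previous_num_tokens[i] is incremented per chunk instead of overwritten, and UsageInfo reads prompt token counts from num_prompt_tokens[prompt_idx] because later streamed outputs no longer repeat prompt_token_ids.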


@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
previous_text_lens = [0] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
num_prompt_tokens = [0] * num_prompts
try:
async for prompt_idx, res in result_generator:
@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = res.prompt_logprobs
prompt_text = res.prompt
# Prompt details are excluded from later streamed outputs
if res.prompt_token_ids is not None:
num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
delta_token_ids: GenericSequence[int]
out_logprobs: Optional[GenericSequence[Optional[Dict[
int, Logprob]]]]
@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
assert prompt_token_ids is not None
assert prompt_text is not None
# only return the prompt
delta_text = prompt_text
@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
assert prompt_token_ids is not None
assert prompt_text is not None
assert prompt_logprobs is not None
# echo the prompt and first token
@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[
previous_num_tokens[i]:]
out_logprobs = output.logprobs[previous_num_tokens[
i]:] if output.logprobs else None
delta_text = output.text
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
if request.logprobs is not None:
assert out_logprobs is not None, (
@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
initial_text_offset=previous_text_lens[i],
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
previous_text_lens[i] += len(output.text)
previous_num_tokens[i] += len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids)
prompt_tokens = num_prompt_tokens[prompt_idx]
completion_tokens = previous_num_tokens[i]
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
for final_res in final_res_batch:
prompt_token_ids = final_res.prompt_token_ids
assert prompt_token_ids is not None
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
)
choices.append(choice_data)
num_generated_tokens += len(output.token_ids)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
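On the engine side, the option named in the title selects between cumulative outputs, per-step deltas, and a single final output. A hypothetical usage sketch follows; it assumes the option is exposed as an output_kind field on SamplingParams backed by a RequestOutputKind enum (CUMULATIVE / DELTA / FINAL_ONLY), which is not shown in this diff, so treat those names as assumptions:

    # Hypothetical sketch (output_kind / RequestOutputKind are assumed names):
    # ask the engine for delta outputs and reassemble the text client-side.
    import asyncio

    from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
    from vllm.sampling_params import RequestOutputKind  # assumed import path

    async def main():
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))
        params = SamplingParams(
            max_tokens=32,
            output_kind=RequestOutputKind.DELTA,  # only new tokens per step
        )
        text = ""
        async for out in engine.generate("Hello, my name is", params, "req-0"):
            # With DELTA, out.outputs[0].text holds just the new fragment.
            text += out.outputs[0].text
        print(text)

    asyncio.run(main())

With FINAL_ONLY, the generator would instead yield once with the complete output, which is the cheaper choice for non-streaming endpoints.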