[Core] Add engine option to return only deltas or final output (#7381)
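In the completions streaming path, engine results can now carry only the newly generated text and token IDs instead of the full cumulative output, so the handler below stops slicing cumulative strings per choice and instead keeps running counters (previous_text_lens, previous_num_tokens). A minimal sketch of that bookkeeping change, using SimpleNamespace records as hypothetical stand-ins for the engine's output objects:

    from types import SimpleNamespace

    # Before: streamed outputs were cumulative, so each step sliced off
    # the already-sent prefix and had to retain the full text so far.
    previous_texts = [""]

    def delta_from_cumulative(i, output):
        delta = output.text[len(previous_texts[i]):]
        previous_texts[i] = output.text
        return delta

    # After: outputs already contain only the delta; a running character
    # count is enough (still needed for logprob text offsets below).
    previous_text_lens = [0]

    def delta_from_delta(i, output):
        previous_text_lens[i] += len(output.text)
        return output.text

    # Both styles yield the same stream for the generation "Hello world".
    assert delta_from_cumulative(0, SimpleNamespace(text="Hello")) == "Hello"
    assert delta_from_cumulative(0, SimpleNamespace(text="Hello world")) == " world"
    assert delta_from_delta(0, SimpleNamespace(text="Hello")) == "Hello"
    assert delta_from_delta(0, SimpleNamespace(text=" world")) == " world"

The delta style avoids retaining and re-slicing the full generated text for every chunk of every choice.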
@@ -223,9 +223,10 @@ class OpenAIServingCompletion(OpenAIServing):
         tokenizer: AnyTokenizer,
     ) -> AsyncGenerator[str, None]:
         num_choices = 1 if request.n is None else request.n
-        previous_texts = [""] * num_choices * num_prompts
+        previous_text_lens = [0] * num_choices * num_prompts
         previous_num_tokens = [0] * num_choices * num_prompts
         has_echoed = [False] * num_choices * num_prompts
+        num_prompt_tokens = [0] * num_prompts
 
         try:
             async for prompt_idx, res in result_generator:
@@ -233,6 +234,10 @@ class OpenAIServingCompletion(OpenAIServing):
                 prompt_logprobs = res.prompt_logprobs
                 prompt_text = res.prompt
 
+                # Prompt details are excluded from later streamed outputs
+                if res.prompt_token_ids is not None:
+                    num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
+
                 delta_token_ids: GenericSequence[int]
                 out_logprobs: Optional[GenericSequence[Optional[Dict[
                     int, Logprob]]]]
@@ -244,6 +249,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
                     assert request.max_tokens is not None
                     if request.echo and request.max_tokens == 0:
+                        assert prompt_token_ids is not None
                         assert prompt_text is not None
                         # only return the prompt
                         delta_text = prompt_text
@@ -252,6 +258,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         has_echoed[i] = True
                     elif (request.echo and request.max_tokens > 0
                           and not has_echoed[i]):
+                        assert prompt_token_ids is not None
                         assert prompt_text is not None
                         assert prompt_logprobs is not None
                         # echo the prompt and first token
@@ -266,11 +273,9 @@ class OpenAIServingCompletion(OpenAIServing):
                         has_echoed[i] = True
                     else:
                         # return just the delta
-                        delta_text = output.text[len(previous_texts[i]):]
-                        delta_token_ids = output.token_ids[
-                            previous_num_tokens[i]:]
-                        out_logprobs = output.logprobs[previous_num_tokens[
-                            i]:] if output.logprobs else None
+                        delta_text = output.text
+                        delta_token_ids = output.token_ids
+                        out_logprobs = output.logprobs
 
                     if request.logprobs is not None:
                         assert out_logprobs is not None, (
@@ -280,13 +285,13 @@ class OpenAIServingCompletion(OpenAIServing):
                             top_logprobs=out_logprobs,
                             num_output_top_logprobs=request.logprobs,
                             tokenizer=tokenizer,
-                            initial_text_offset=len(previous_texts[i]),
+                            initial_text_offset=previous_text_lens[i],
                         )
                     else:
                         logprobs = None
 
-                    previous_texts[i] = output.text
-                    previous_num_tokens[i] = len(output.token_ids)
+                    previous_text_lens[i] += len(output.text)
+                    previous_num_tokens[i] += len(output.token_ids)
                     finish_reason = output.finish_reason
                     stop_reason = output.stop_reason
 
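Why a running length is still needed: the text offsets in completion logprobs are absolute positions in the full completion text, so each delta chunk must know how many characters preceded it, which is what initial_text_offset=previous_text_lens[i] passes into the logprobs-construction call above. A rough illustration, where token_offsets is a hypothetical helper standing in for that offset bookkeeping:

    def token_offsets(tokens, initial_text_offset):
        # Each token's offset is its absolute start position in the
        # concatenated completion text, not within the current chunk.
        offsets, pos = [], initial_text_offset
        for tok in tokens:
            offsets.append(pos)
            pos += len(tok)
        return offsets

    # First chunk carries "Hel" + "lo"; second chunk " wor" + "ld".
    assert token_offsets(["Hel", "lo"], 0) == [0, 3]
    assert token_offsets([" wor", "ld"], 5) == [5, 9]  # 5 == previous_text_lens[i]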
@@ -307,8 +312,8 @@ class OpenAIServingCompletion(OpenAIServing):
                             and request.stream_options.include_usage):
                     if (request.stream_options.continuous_usage_stats
                             or output.finish_reason is not None):
-                        prompt_tokens = len(prompt_token_ids)
-                        completion_tokens = len(output.token_ids)
+                        prompt_tokens = num_prompt_tokens[prompt_idx]
+                        completion_tokens = previous_num_tokens[i]
                         usage = UsageInfo(
                             prompt_tokens=prompt_tokens,
                             completion_tokens=completion_tokens,
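With delta outputs, res.prompt_token_ids is only populated on the first streamed chunk and output.token_ids holds just that chunk's new tokens, so len(prompt_token_ids) and len(output.token_ids) no longer give request totals; the cached num_prompt_tokens[prompt_idx] and the running previous_num_tokens[i] are used instead. A small sketch of that accounting, under the same SimpleNamespace stand-in assumption as above:

    from types import SimpleNamespace

    num_prompt_tokens = [0]    # cached per prompt from the first chunk
    previous_num_tokens = [0]  # running completion-token count per choice

    def usage_for_chunk(prompt_idx, i, res, output):
        if res.prompt_token_ids is not None:  # present on the first chunk only
            num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids)
        previous_num_tokens[i] += len(output.token_ids)
        return num_prompt_tokens[prompt_idx], previous_num_tokens[i]

    first = SimpleNamespace(prompt_token_ids=[10, 11, 12])
    later = SimpleNamespace(prompt_token_ids=None)
    assert usage_for_chunk(0, 0, first, SimpleNamespace(token_ids=[1])) == (3, 1)
    assert usage_for_chunk(0, 0, later, SimpleNamespace(token_ids=[2, 3])) == (3, 3)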
@@ -356,6 +361,7 @@ class OpenAIServingCompletion(OpenAIServing):
 
         for final_res in final_res_batch:
             prompt_token_ids = final_res.prompt_token_ids
+            assert prompt_token_ids is not None
             prompt_logprobs = final_res.prompt_logprobs
             prompt_text = final_res.prompt
 
@@ -411,9 +417,9 @@ class OpenAIServingCompletion(OpenAIServing):
                 )
                 choices.append(choice_data)
 
-                num_generated_tokens += len(output.token_ids)
-
             num_prompt_tokens += len(prompt_token_ids)
+            num_generated_tokens += sum(
+                len(output.token_ids) for output in final_res.outputs)
 
         usage = UsageInfo(
             prompt_tokens=num_prompt_tokens,
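In the non-streaming path, the per-output accumulation moves out of the inner loop: generated tokens for a request are now summed over all of that request's outputs (e.g. its n parallel choices) in one expression. For instance, with illustrative data:

    from types import SimpleNamespace

    # Illustrative final result with n=2 choices for one prompt.
    final_res = SimpleNamespace(outputs=[
        SimpleNamespace(token_ids=[1, 2, 3]),
        SimpleNamespace(token_ids=[4, 5]),
    ])
    num_generated_tokens = sum(
        len(output.token_ids) for output in final_res.outputs)
    assert num_generated_tokens == 5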