[V1] Aggregate chunked prompt logprobs in model runner (#14875)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-03-24 09:27:57 -07:00
committed by GitHub
parent 9cc645141d
commit 3aee6573dc
7 changed files with 68 additions and 44 deletions

View File

@@ -115,7 +115,6 @@ class LogprobsProcessor:
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()

View File

@@ -105,9 +105,7 @@ class RequestState:
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (self.is_prefilling or final_only):
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None
@@ -285,19 +283,7 @@ class OutputProcessor:
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
@@ -306,8 +292,7 @@ class OutputProcessor:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.