[Perf] Introduce FlattenLogprobs to store logprobs results to reduce GC overhead (#28171)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Author:    Jialin Ouyang
Date:      2025-11-07 00:27:12 -08:00
Committer: GitHub
parent 21b82f4ea2
commit ccd98b59c1
6 changed files with 534 additions and 125 deletions
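
The motivation: the processor previously stored one dict[int, Logprob] per position (see the removed _make_logprob_dict below), so a long generation with topk logprobs allocates thousands of small GC-tracked objects. A flattened, struct-of-arrays container holds the same data in a few parallel lists that the collector scans as a handful of objects. A minimal sketch of that layout, with illustrative field names only (the real FlattenLogprobs in vllm.logprobs may differ):

    from dataclasses import dataclass, field

    @dataclass
    class FlatLogprobsSketch:
        """Struct-of-arrays logprob storage (hypothetical names).

        Position i owns the slice [start_indices[i], start_indices[i + 1])
        of each parallel list, replacing one dict[int, Logprob] per position.
        """

        start_indices: list[int] = field(default_factory=lambda: [0])
        token_ids: list[int] = field(default_factory=list)
        logprobs: list[float] = field(default_factory=list)
        ranks: list[int] = field(default_factory=list)
        decoded_tokens: list[str | None] = field(default_factory=list)

        def append_position(
            self,
            token_ids: list[int],
            logprobs: list[float],
            ranks: list[int],
            decoded_tokens: list[str | None],
        ) -> None:
            # Extending flat lists with primitives allocates no per-position
            # dict or Logprob objects, which is where the GC savings come from.
            self.token_ids.extend(token_ids)
            self.logprobs.extend(logprobs)
            self.ranks.extend(ranks)
            self.decoded_tokens.extend(decoded_tokens)
            self.start_indices.append(len(self.token_ids))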

@@ -2,11 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import itertools
-from collections.abc import Iterable
 from dataclasses import dataclass

 from vllm.logger import init_logger
-from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.logprobs import (
+    PromptLogprobs,
+    SampleLogprobs,
+    append_logprobs_for_next_position,
+    create_prompt_logprobs,
+    create_sample_logprobs,
+)
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer,
     convert_ids_list_to_tokens,
@@ -44,9 +49,10 @@ class LogprobsProcessor:
         return cls(
             tokenizer=tokenizer,
             cumulative_logprob=(None if num_logprobs is None else 0.0),
-            logprobs=(None if num_logprobs is None else []),
-            # NOTE: logprob of first prompt token is None.
-            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
+            logprobs=(None if num_logprobs is None else create_sample_logprobs()),
+            prompt_logprobs=(
+                None if num_prompt_logprobs is None else create_prompt_logprobs()
+            ),
             num_prompt_logprobs=num_prompt_logprobs,
             num_logprobs=num_logprobs,
         )
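
The removed NOTE explains the old [None] seed: the first prompt token has no logprob because nothing precedes it. The new factories presumably fold that seeding into the container itself. A hedged sketch of the two, reusing the FlatLogprobsSketch above (the _sketch helpers are hypothetical, not the vllm.logprobs API):

    def create_sample_logprobs_sketch() -> FlatLogprobsSketch:
        # Sample logprobs start empty; positions arrive as tokens are sampled.
        return FlatLogprobsSketch()

    def create_prompt_logprobs_sketch() -> FlatLogprobsSketch:
        # Seed an empty slice for the first prompt token, standing in for
        # the leading None entry of the old list-of-dicts representation.
        logprobs = FlatLogprobsSketch()
        logprobs.append_position([], [], [], [])
        return logprobs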
@@ -80,15 +86,14 @@ class LogprobsProcessor:
             sampled_token_logprob = logprobs[0]
             self.cumulative_logprob += sampled_token_logprob

-            # Update with the Logprob dictionary for this pos.
-            self.logprobs.append(
-                self._make_logprob_dict(
-                    logprobs,
-                    token_ids,
-                    decoded_tokens,
-                    rank,
-                    self.num_logprobs,
-                )
-            )
+            # Update with the Logprob container for this pos.
+            append_logprobs_for_next_position(
+                self.logprobs,
+                token_ids,
+                logprobs,
+                decoded_tokens,
+                rank,
+                self.num_logprobs,
+            )

     def _update_prompt_logprobs(
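
Note the new call takes the container first and token_ids before logprobs, the reverse of _make_logprob_dict's first two arguments. A hedged sketch of what append_logprobs_for_next_position might do with the flattened layout, mirroring the rank logic of the removed helper (the _sketch name is hypothetical):

    import itertools

    def append_logprobs_for_next_position_sketch(
        flat: FlatLogprobsSketch,
        token_ids: list[int],
        logprobs: list[float],
        decoded_tokens: list[str | None],
        rank: int,
        num_logprobs: int,
    ) -> None:
        # Entry 0 is the sampled (or prompt) token at its true rank; entries
        # 1..num_logprobs are the topk tokens at ranks 1..num_logprobs,
        # exactly as in the removed _make_logprob_dict.
        if num_logprobs == -1:
            num_logprobs = len(logprobs)
        ranks = itertools.chain((rank,), range(1, num_logprobs + 1))
        flat.append_position(token_ids, logprobs, list(ranks), decoded_tokens)

Unlike the dict version, a sampled token that also appears in the topk is stored twice here; the real helper may deduplicate, for example when converting back to per-position dicts for the API.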
@@ -136,15 +141,14 @@ class LogprobsProcessor:
                 NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
             )

-            # Update with the Logprob dictionary for this pos.
-            self.prompt_logprobs.append(
-                self._make_logprob_dict(
-                    prompt_logprobs[pos],
-                    token_ids[pos],
-                    decoded_tokens_for_pos,
-                    prompt_token_ranks[pos],
-                    self.num_prompt_logprobs,
-                )
-            )
+            # Update with the Logprob container for this pos.
+            append_logprobs_for_next_position(
+                self.prompt_logprobs,
+                token_ids[pos],
+                prompt_logprobs[pos],
+                decoded_tokens_for_pos,
+                prompt_token_ranks[pos],
+                self.num_prompt_logprobs,
+            )

     def pop_prompt_logprobs(self) -> PromptLogprobs | None:
@@ -166,46 +170,6 @@ class LogprobsProcessor:
             self.prompt_logprobs = []
         return plp

-    @staticmethod
-    def _make_logprob_dict(
-        logprobs: list[float],
-        logprob_token_ids: list[int],
-        decoded_tokens: Iterable[str | None],
-        rank: int,
-        num_logprobs: int,
-    ) -> dict[int, Logprob]:
-        """Make a Logprob dictionary for a position.
-
-        Args:
-            logprobs: list of log probabilities
-            logprob_token_ids: list of top token ids
-            decoded_tokens: list of decoded top tokens
-            rank: rank of the sampled token
-            num_logprobs: number of logprobs requested
-                by the user (in addition to sampled logprob)
-
-        Returns:
-            dict[token id, Logprob]
-        """
-        if num_logprobs == -1:
-            num_logprobs = len(logprobs)
-        # We do not need a special case for the sampled token
-        # being in the topk, since inserting duplicated data
-        # into a dictionary twice is the same as doing it once.
-        topk_ranks = range(1, num_logprobs + 1)
-        ranks = itertools.chain((rank,), topk_ranks)
-        return {
-            token_id: Logprob(
-                logprob=logprob,
-                rank=rank,
-                decoded_token=token,
-            )
-            for token_id, logprob, rank, token in zip(
-                logprob_token_ids, logprobs, ranks, decoded_tokens
-            )
-        }
-
     def update_from_output(self, output: EngineCoreOutput) -> None:
         if output.new_logprobs is not None:
             self._update_sample_logprobs(output.new_logprobs)
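
The GC claim in the title can be seen in miniature: every dict and every Logprob instance built by the removed helper is a GC-tracked object, while the flattened layout adds a fixed handful of lists no matter how long the generation runs. A rough, illustrative comparison reusing the sketches above (not the commit's benchmark):

    import gc
    from dataclasses import dataclass

    @dataclass
    class Logprob:  # stand-in with the fields used by the removed helper
        logprob: float
        rank: int
        decoded_token: str | None

    def new_tracked_objects(build):
        """Rough count of GC-tracked objects a builder leaves behind."""
        gc.collect()
        before = len(gc.get_objects())
        keep = build()  # hold a reference so nothing is collected early
        return len(gc.get_objects()) - before, keep

    POSITIONS, TOPK = 1024, 5

    # Old layout: one dict plus TOPK + 1 Logprob instances per position.
    old_count, old = new_tracked_objects(lambda: [
        {t: Logprob(-0.5, t + 1, None) for t in range(TOPK + 1)}
        for _ in range(POSITIONS)
    ])

    # New layout: one container whose five flat lists grow in place.
    def build_flat() -> FlatLogprobsSketch:
        flat = FlatLogprobsSketch()
        for _ in range(POSITIONS):
            flat.append_position(
                list(range(TOPK + 1)),
                [-0.5] * (TOPK + 1),
                list(range(1, TOPK + 2)),
                [None] * (TOPK + 1),
            )
        return flat

    new_count, new = new_tracked_objects(build_flat)
    print(old_count)  # thousands of tracked dicts and instances
    print(new_count)  # a handful: the container and its lists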