[Misc] FlattenLogprobs -> FlatLogprobs (#28335)

This commit is contained in:
Zhuohan Li
2025-11-10 19:41:23 -08:00
committed by GitHub
parent 57201a6a4c
commit 8d706cca90
4 changed files with 43 additions and 47 deletions

View File

@@ -223,7 +223,7 @@ if TYPE_CHECKING:
VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_FLATTEN_LOGPROBS: bool = False
VLLM_FLAT_LOGPROBS: bool = False
def get_default_cache_root():
@@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
),
# Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
# Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
# the original list[dict[int, Logprob]] approach.
# Once enabled, PromptLogprobs and SampleLogprobs would be populated as
# FlattenLogprobs.
"VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
# FlatLogprobs.
"VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
}
# --8<-- [end:env-vars-definition]

View File

@@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]
@dataclass
class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
"""
Flatten logprobs of a request into multiple primitive type lists.
Flat logprobs of a request into multiple primitive type lists.
Compared to list[dict[int, Logprob]], this data structure reduces GC
overhead significantly, as it flattens logprob information for
all positions and ranks into multiple primitive type lists (i.e.
logprobs, token_ids, ranks per token_ids, decoded_tokens).
So regardless of the sequence length and top_logprobs setup,
FlattenLogprobs would only introduce a constant amount of objects.
FlatLogprobs would only introduce a constant amount of objects.
As each position might contain a different number of ranks,
start_indices_per_position would be used to access the logprob ranges
@@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
def __getitem__(self, position: int) -> LogprobsOnePosition: ...
@overload
def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...
def __getitem__(self, index: int | slice):
"""Extracts logprobs of a given position or slice"""
@@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
elif isinstance(index, slice):
min_index = self.start_indices[index][0]
max_index = self.end_indices[index][-1]
return FlattenLogprobs(
return FlatLogprobs(
# Shift updated start_indices and end_indices to
# be 0-indexed
start_indices=[i - min_index for i in self.start_indices[index]],
@@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
raise TypeError(f"Invalid index type: {type(index)}")
def __setitem__(self, item, value) -> None:
raise TypeError("Cannot set logprobs in FlattenLogprobs")
raise TypeError("Cannot set logprobs in FlatLogprobs")
def __delitem__(self, item) -> None:
raise TypeError("Cannot delete logprobs from FlattenLogprobs")
raise TypeError("Cannot delete logprobs from FlatLogprobs")
def insert(self, item) -> None:
raise TypeError("Cannot insert logprobs to FlattenLogprobs")
raise TypeError("Cannot insert logprobs to FlatLogprobs")
def __iter__(self) -> Iterator[LogprobsOnePosition]:
"""
@@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
# {token_id -> logprob} per each sequence group. None if the corresponding
# sequence group doesn't require prompt logprob.
PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
# {token_id -> logprob} for each sequence group.
SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
def create_prompt_logprobs() -> PromptLogprobs:
"""Creates a container to store prompt logprobs for a request"""
logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
# NOTE: logprob of first prompt token is None.
logprobs.append(None)
return logprobs
@@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:
def create_sample_logprobs() -> SampleLogprobs:
"""Creates a container to store decode logprobs for a request"""
return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
def append_logprobs_for_next_position(
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank,), topk_ranks)
if isinstance(request_logprobs, FlattenLogprobs):
if isinstance(request_logprobs, FlatLogprobs):
request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
else:
request_logprobs.append(