[Misc] FlattenLogprobs -> FlatLogprobs (#28335)
@@ -223,7 +223,7 @@ if TYPE_CHECKING:
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
-    VLLM_FLATTEN_LOGPROBS: bool = False
+    VLLM_FLAT_LOGPROBS: bool = False


 def get_default_cache_root():
@@ -1481,11 +1481,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
         "VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
     ),
-    # Flag to enable FlattenLogprobs whose GC overhead is significantly smaller than
+    # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
     # the original list[dict[int, Logprob]] approach.
     # After enabled, PromptLogprobs and SampleLogprobs would populated as
-    # FlattenLogprobs.
-    "VLLM_FLATTEN_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLATTEN_LOGPROBS", "0"))),
+    # FlatLogprobs.
+    "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
 }

 # --8<-- [end:env-vars-definition]
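The flag keeps its semantics and only changes name: VLLM_FLAT_LOGPROBS=1 switches logprob storage to the flat container. A minimal sketch of turning it on, assuming the module paths vllm.envs and vllm.logprobs (file names are not shown in this diff) and a build that already includes the rename:

import os

# Hedged sketch, not part of the commit: enable the flat logprobs container
# before vLLM reads its environment variables.
os.environ["VLLM_FLAT_LOGPROBS"] = "1"   # parsed with bool(int(...)) as defined above

import vllm.envs as envs                                          # assumed module path
from vllm.logprobs import FlatLogprobs, create_sample_logprobs    # assumed module path

assert envs.VLLM_FLAT_LOGPROBS
# With the flag on, the factory returns FlatLogprobs instead of a plain list.
assert isinstance(create_sample_logprobs(), FlatLogprobs)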
@@ -30,16 +30,16 @@ LogprobsOnePosition = dict[int, Logprob]


 @dataclass
-class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
+class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
     """
-    Flatten logprobs of a request into multiple primitive type lists.
+    Flat logprobs of a request into multiple primitive type lists.

     Compared to list[dict[int, Logprob]], this data structure reduced GC
     overhead significantly. As it flattened logprob information for
     all positions and ranks in to multiple primitive type lists (i.e.
     logprobs, token_ids, ranks per token_ids, decoded_tokens).
     So regardless of the sequence length and top_logprobs setup,
-    FlattenLogprobs would only introduce a constant amount of objects.
+    FlatLogprobs would only introduce a constant amount of objects.

     As each position might contains different amount of ranks,
     start_indices_per_position would be used to access the logprob ranges
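The docstring describes a structure-of-arrays layout: one set of primitive lists for the whole request, with per-position ranges tracked by start and end indices. A self-contained illustration of that idea with made-up numbers (plain Python, not the class itself):

# Illustration only: how two positions' logprobs flatten into primitive lists.
start_indices = [0, 2]            # first flat index for each position
end_indices = [2, 5]              # one-past-last flat index for each position
token_ids = [11, 42, 7, 11, 99]
logprobs = [-0.1, -2.3, -0.5, -1.0, -4.2]
ranks = [1, 2, 1, 2, 3]

def logprobs_at(position: int) -> dict[int, float]:
    """Rebuild a {token_id: logprob} mapping for one position from the flat lists."""
    s, e = start_indices[position], end_indices[position]
    return dict(zip(token_ids[s:e], logprobs[s:e]))

assert logprobs_at(0) == {11: -0.1, 42: -2.3}
assert logprobs_at(1) == {7: -0.5, 11: -1.0, 99: -4.2}

Regardless of sequence length or top_logprobs, only these few list objects exist, which is the GC win the comment above refers to.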
@@ -107,7 +107,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
     def __getitem__(self, position: int) -> LogprobsOnePosition: ...

     @overload
-    def __getitem__(self, s: slice, /) -> "FlattenLogprobs": ...
+    def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...

     def __getitem__(self, index: int | slice):
         """Extracts logprobs of a given position or slice"""
@@ -123,7 +123,7 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
         elif isinstance(index, slice):
             min_index = self.start_indices[index][0]
             max_index = self.end_indices[index][-1]
-            return FlattenLogprobs(
+            return FlatLogprobs(
                 # Shift updated start_indices and end_indices to
                 # be 0-indexed
                 start_indices=[i - min_index for i in self.start_indices[index]],
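Integer indexing rebuilds the per-position dict, while slicing returns a new FlatLogprobs with its indices shifted back to zero, as the hunk above shows. A hedged usage sketch, assuming the vllm.logprobs module path and that append_fast takes per-position (token_ids, logprobs, ranks, decoded_tokens) as in the final hunk of this commit:

from vllm.logprobs import FlatLogprobs   # assumed module path

flat = FlatLogprobs()
flat.append_fast([11, 42], [-0.1, -2.3], [1, 2], ["a", "b"])                    # position 0
flat.append_fast([7, 11, 99], [-0.5, -1.0, -4.2], [1, 2, 3], ["c", "a", "d"])   # position 1

position_zero = flat[0]   # LogprobsOnePosition dict for position 0
tail = flat[1:]           # new FlatLogprobs holding only position 1,
                          # with start_indices/end_indices re-based to 0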
@@ -137,13 +137,13 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):
             raise TypeError(f"Invalid index type: {type(index)}")

     def __setitem__(self, item, value) -> None:
-        raise TypeError("Cannot set logprobs in FlattenLogprobs")
+        raise TypeError("Cannot set logprobs in FlatLogprobs")

     def __delitem__(self, item) -> None:
-        raise TypeError("Cannot delete logprobs from FlattenLogprobs")
+        raise TypeError("Cannot delete logprobs from FlatLogprobs")

     def insert(self, item) -> None:
-        raise TypeError("Cannot insert logprobs to FlattenLogprobs")
+        raise TypeError("Cannot insert logprobs to FlatLogprobs")

     def __iter__(self) -> Iterator[LogprobsOnePosition]:
         """
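All three mutating hooks raise, so positions can only be added through append or append_fast. The resulting behavior, sketched under the same import assumption:

from vllm.logprobs import FlatLogprobs   # assumed module path

flat = FlatLogprobs()
try:
    flat[0] = {}            # __setitem__ refuses in-place mutation
except TypeError as err:
    print(err)              # "Cannot set logprobs in FlatLogprobs"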
@@ -156,14 +156,14 @@ class FlattenLogprobs(MutableSequence[LogprobsOnePosition]):

 # {token_id -> logprob} per each sequence group. None if the corresponding
 # sequence group doesn't require prompt logprob.
-PromptLogprobs = FlattenLogprobs | list[LogprobsOnePosition | None]
+PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
 # {token_id -> logprob} for each sequence group.
-SampleLogprobs = FlattenLogprobs | list[LogprobsOnePosition]
+SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]


 def create_prompt_logprobs() -> PromptLogprobs:
     """Creates a container to store prompt logprobs for a request"""
-    logprobs = FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
     # NOTE: logprob of first prompt token is None.
     logprobs.append(None)
     return logprobs
@@ -171,7 +171,7 @@ def create_prompt_logprobs() -> PromptLogprobs:

 def create_sample_logprobs() -> SampleLogprobs:
     """Creates a container to store decode logprobs for a request"""
-    return FlattenLogprobs() if envs.VLLM_FLATTEN_LOGPROBS else []
+    return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []


 def append_logprobs_for_next_position(
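Both factories pick the container type from the environment flag, and the prompt-side container always starts with a placeholder for the first prompt token, which never has a logprob. What a caller sees, sketched under the same import assumption:

from vllm.logprobs import create_prompt_logprobs, create_sample_logprobs   # assumed path

prompt_logprobs = create_prompt_logprobs()
print(len(prompt_logprobs))     # 1: placeholder entry for the first prompt token

sample_logprobs = create_sample_logprobs()
# Plain list when VLLM_FLAT_LOGPROBS is unset or 0, FlatLogprobs when it is 1.
print(type(sample_logprobs).__name__)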
@@ -191,7 +191,7 @@ def append_logprobs_for_next_position(
     topk_ranks = range(1, num_logprobs + 1)
     ranks = itertools.chain((rank,), topk_ranks)

-    if isinstance(request_logprobs, FlattenLogprobs):
+    if isinstance(request_logprobs, FlatLogprobs):
         request_logprobs.append_fast(token_ids, logprobs, ranks, decoded_tokens)
     else:
         request_logprobs.append(
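The writer dispatches on the container type, so the hot path extends primitive lists instead of allocating a dict of Logprob objects per position. A self-contained toy version of that dispatch (illustrative names only, not vLLM's):

from dataclasses import dataclass, field

@dataclass
class TinyFlat:
    token_ids: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)

    def append_fast(self, token_ids, logprobs):
        # Flat path: extend primitive lists, no per-position dict is created.
        self.token_ids.extend(token_ids)
        self.logprobs.extend(logprobs)

def store_position(container, token_ids, logprobs):
    if isinstance(container, TinyFlat):
        container.append_fast(token_ids, logprobs)
    else:
        # Legacy path: one {token_id: logprob} dict per position.
        container.append(dict(zip(token_ids, logprobs)))

store_position(TinyFlat(), [11, 42], [-0.1, -2.3])
store_position([], [11, 42], [-0.1, -2.3])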