[Core] Switch Flat logprob control from environment variable to SamplingParams (#28914)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -225,7 +225,6 @@ if TYPE_CHECKING:
|
||||
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
|
||||
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
|
||||
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
|
||||
VLLM_FLAT_LOGPROBS: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@@ -1499,11 +1498,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_COMPILE_CACHE_SAVE_FORMAT": env_with_choices(
|
||||
"VLLM_COMPILE_CACHE_SAVE_FORMAT", "binary", ["binary", "unpacked"]
|
||||
),
|
||||
# Flag to enable FlatLogprobs whose GC overhead is significantly smaller than
|
||||
# the original list[dict[int, Logprob]] approach.
|
||||
# Once enabled, PromptLogprobs and SampleLogprobs are populated as
|
||||
# FlatLogprobs.
|
||||
"VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
||||
@@ -5,8 +5,6 @@ from collections.abc import Iterable, Iterator, MutableSequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import overload
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
|
||||
# We use dataclass for now because it is used for
|
||||
# openai server output, and msgspec is not serializable.
|
||||
@@ -161,17 +159,17 @@ PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
|
||||
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
|
||||
|
||||
|
||||
def create_prompt_logprobs() -> PromptLogprobs:
|
||||
def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
|
||||
"""Creates a container to store prompt logprobs for a request"""
|
||||
logprobs = FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
|
||||
logprobs = FlatLogprobs() if flat_logprobs else []
|
||||
# NOTE: logprob of first prompt token is None.
|
||||
logprobs.append(None)
|
||||
return logprobs
|
||||
|
||||
|
||||
def create_sample_logprobs() -> SampleLogprobs:
|
||||
def create_sample_logprobs(flat_logprobs: bool) -> SampleLogprobs:
|
||||
"""Creates a container to store decode logprobs for a request"""
|
||||
return FlatLogprobs() if envs.VLLM_FLAT_LOGPROBS else []
|
||||
return FlatLogprobs() if flat_logprobs else []
|
||||
|
||||
|
||||
def append_logprobs_for_next_position(
|
||||
|
||||
@@ -204,6 +204,12 @@ class SamplingParams(
|
||||
prompt_logprobs: int | None = None
|
||||
"""Number of log probabilities to return per prompt token.
|
||||
When set to -1, return all `vocab_size` log probabilities."""
|
||||
flat_logprobs: bool = False
|
||||
"""Whether to return logprobs in flattened format (i.e. FlatLogprobs)
|
||||
for better performance.
|
||||
NOTE: The GC cost of FlatLogprobs is significantly smaller than that of
|
||||
list[dict[int, Logprob]]. Once enabled, PromptLogprobs and
|
||||
SampleLogprobs would be populated as FlatLogprobs."""
|
||||
# NOTE: This parameter is only exposed at the engine level for now.
|
||||
# It is not exposed in the OpenAI API server, as the OpenAI API does
|
||||
# not support returning only a list of token IDs.
|
||||
|
||||
@@ -43,15 +43,22 @@ class LogprobsProcessor:
|
||||
tokenizer: AnyTokenizer | None,
|
||||
request: EngineCoreRequest,
|
||||
) -> "LogprobsProcessor":
|
||||
assert request.sampling_params is not None
|
||||
num_logprobs = request.sampling_params.logprobs
|
||||
num_prompt_logprobs = request.sampling_params.prompt_logprobs
|
||||
sampling_params = request.sampling_params
|
||||
assert sampling_params is not None
|
||||
num_logprobs = sampling_params.logprobs
|
||||
num_prompt_logprobs = sampling_params.prompt_logprobs
|
||||
return cls(
|
||||
tokenizer=tokenizer,
|
||||
cumulative_logprob=(None if num_logprobs is None else 0.0),
|
||||
logprobs=(None if num_logprobs is None else create_sample_logprobs()),
|
||||
logprobs=(
|
||||
None
|
||||
if num_logprobs is None
|
||||
else create_sample_logprobs(sampling_params.flat_logprobs)
|
||||
),
|
||||
prompt_logprobs=(
|
||||
None if num_prompt_logprobs is None else create_prompt_logprobs()
|
||||
None
|
||||
if num_prompt_logprobs is None
|
||||
else create_prompt_logprobs(sampling_params.flat_logprobs)
|
||||
),
|
||||
num_prompt_logprobs=num_prompt_logprobs,
|
||||
num_logprobs=num_logprobs,
|
||||
|
||||
Reference in New Issue
Block a user