[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
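This change converts `SamplingParams` from a hand-written class into a `msgspec.Struct` so that sampling parameters can be serialized compactly and cheaply when broadcast across SPMD workers, dropping this module's pydantic dependency in favor of msgspec.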
@@ -2,10 +2,10 @@
 import copy
 from enum import IntEnum
 from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Union
 
+import msgspec
 import torch
-from pydantic import Field
 from typing_extensions import Annotated
 
 from vllm.logger import init_logger
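One consequence of the import swap: the `ge=1` bound on `truncate_prompt_tokens` moves from pydantic's `Field` to `msgspec.Meta`, and msgspec enforces it whenever a message is decoded into the typed struct. A minimal standalone sketch (illustrative `Params` struct, not vllm code):

    import msgspec
    from typing import Optional
    from typing_extensions import Annotated

    class Params(msgspec.Struct):
        # same Annotated pattern the diff applies to truncate_prompt_tokens
        truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None

    msgspec.json.decode(b'{"truncate_prompt_tokens": 4}', type=Params)  # ok
    msgspec.json.decode(b'{"truncate_prompt_tokens": 0}', type=Params)  # raises msgspec.ValidationError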
@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
 to sample from."""
 
 
-class SamplingParams:
+class SamplingParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        # required for @cached_property.
+        dict=True):  # type: ignore[call-arg]
     """Sampling parameters for text generation.
 
     Overall, we follow the sampling parameters from the OpenAI text completion
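Two struct options do the heavy lifting here: `omit_defaults=True` drops every field still holding its default from the encoded message, which is what shrinks the serialized `SamplingParams` sent between SPMD workers, and `dict=True` gives instances a `__dict__` (msgspec structs are slotted otherwise), which `@cached_property` needs to store its cached value. A minimal standalone sketch (illustrative `P` struct, not vllm code):

    import msgspec

    class P(msgspec.Struct, omit_defaults=True):
        temperature: float = 1.0
        top_p: float = 1.0

    print(msgspec.json.encode(P()))           # b'{}' -- all defaults omitted
    print(msgspec.json.encode(P(top_p=0.9)))  # b'{"top_p":0.9}'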
@@ -112,87 +116,73 @@ class SamplingParams:
         (i.e., no truncation).
     """
 
-    def __init__(
-        self,
-        n: int = 1,
-        best_of: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repetition_penalty: float = 1.0,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        top_k: int = -1,
-        min_p: float = 0.0,
-        seed: Optional[int] = None,
-        use_beam_search: bool = False,
-        length_penalty: float = 1.0,
-        early_stopping: Union[bool, str] = False,
-        stop: Optional[Union[str, List[str]]] = None,
-        stop_token_ids: Optional[List[int]] = None,
-        include_stop_str_in_output: bool = False,
-        ignore_eos: bool = False,
-        max_tokens: Optional[int] = 16,
-        min_tokens: int = 0,
-        logprobs: Optional[int] = None,
-        prompt_logprobs: Optional[int] = None,
-        detokenize: bool = True,
-        skip_special_tokens: bool = True,
-        spaces_between_special_tokens: bool = True,
-        logits_processors: Optional[List[LogitsProcessor]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
-    ) -> None:
-        self.n = n
-        self.best_of = best_of if best_of is not None else n
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.repetition_penalty = repetition_penalty
-        if 0 < temperature < _MAX_TEMP:
+    n: int = 1
+    best_of: Optional[int] = None
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    repetition_penalty: float = 1.0
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    min_p: float = 0.0
+    seed: Optional[int] = None
+    use_beam_search: bool = False
+    length_penalty: float = 1.0
+    early_stopping: Union[bool, str] = False
+    stop: Optional[Union[str, List[str]]] = None
+    stop_token_ids: Optional[List[int]] = None
+    ignore_eos: bool = False
+    max_tokens: Optional[int] = 16
+    min_tokens: int = 0
+    logprobs: Optional[int] = None
+    prompt_logprobs: Optional[int] = None
+    # NOTE: This parameter is only exposed at the engine level for now.
+    # It is not exposed in the OpenAI API server, as the OpenAI API does
+    # not support returning only a list of token IDs.
+    detokenize: bool = True
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
+    # Optional[List[LogitsProcessor]] type. We use Any here because
+    # Optional[List[LogitsProcessor]] type is not supported by msgspec.
+    logits_processors: Optional[Any] = None
+    include_stop_str_in_output: bool = False
+    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+
+    # The below fields are not supposed to be used as an input.
+    # They are set in post_init.
+    output_text_buffer_length: int = 0
+    _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
+
+    def __post_init__(self) -> None:
+        self.best_of = self.best_of or self.n
+        if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
                 "temperature %s is less than %s, which may cause numerical "
                 "errors nan or inf in tensors. We have maxed it out to %s.",
-                temperature, _MAX_TEMP, _MAX_TEMP)
-            temperature = max(temperature, _MAX_TEMP)
-        self.temperature = temperature
-        self.top_p = top_p
-        self.top_k = top_k
-        self.min_p = min_p
-        if seed == -1:
+                self.temperature, _MAX_TEMP, _MAX_TEMP)
+            self.temperature = max(self.temperature, _MAX_TEMP)
+        if self.seed == -1:
             self.seed = None
         else:
-            self.seed = seed
-        self.use_beam_search = use_beam_search
-        self.length_penalty = length_penalty
-        self.early_stopping = early_stopping
-        if stop is None:
+            self.seed = self.seed
+        if self.stop is None:
             self.stop = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
+        elif isinstance(self.stop, str):
+            self.stop = [self.stop]
         else:
-            self.stop = list(stop)
-        if stop_token_ids is None:
+            self.stop = list(self.stop)
+        if self.stop_token_ids is None:
             self.stop_token_ids = []
         else:
-            self.stop_token_ids = list(stop_token_ids)
-        self.ignore_eos = ignore_eos
-        self.max_tokens = max_tokens
-        self.min_tokens = min_tokens
-        self.logprobs = 1 if logprobs is True else logprobs
-        self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
-        # NOTE: This parameter is only exposed at the engine level for now.
-        # It is not exposed in the OpenAI API server, as the OpenAI API does
-        # not support returning only a list of token IDs.
-        self.detokenize = detokenize
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.logits_processors = logits_processors
-        self.include_stop_str_in_output = include_stop_str_in_output
-        self.truncate_prompt_tokens = truncate_prompt_tokens
+            self.stop_token_ids = list(self.stop_token_ids)
+        self.logprobs = 1 if self.logprobs is True else self.logprobs
+        self.prompt_logprobs = (1 if self.prompt_logprobs is True else
+                                self.prompt_logprobs)
+
         # Number of characters to hold back for stop string evaluation
         # until sequence is finished.
-        if self.stop and not include_stop_str_in_output:
+        if self.stop and not self.include_stop_str_in_output:
             self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
-        else:
-            self.output_text_buffer_length = 0
 
         self._verify_args()
         if self.use_beam_search:
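The mechanical effect of this hunk: msgspec generates `__init__` from the class-level field declarations, and the normalization logic moves into `__post_init__`, which msgspec invokes both after `__init__` and after decoding a serialized message, so inputs are normalized on both construction paths. A minimal standalone sketch (illustrative `Req` struct, not vllm code):

    import msgspec
    from typing import List, Optional, Union

    class Req(msgspec.Struct):
        stop: Optional[Union[str, List[str]]] = None

        def __post_init__(self) -> None:
            # msgspec calls this after __init__ and after decoding
            if self.stop is None:
                self.stop = []
            elif isinstance(self.stop, str):
                self.stop = [self.stop]

    print(Req(stop="###").stop)                                    # ['###']
    print(msgspec.json.decode(b'{"stop": "###"}', type=Req).stop)  # ['###']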
@@ -206,11 +196,12 @@ class SamplingParams:
             self.min_p = 0.0
             self._verify_greedy_sampling()
         # eos_token_id is added to this by the engine
-        self.all_stop_token_ids = set(self.stop_token_ids)
+        self._all_stop_token_ids = set(self.stop_token_ids)
 
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
+        assert isinstance(self.best_of, int)
         if self.best_of < self.n:
             raise ValueError(f"best_of must be greater than or equal to n, "
                              f"got n={self.n} and best_of={self.best_of}.")
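Why the asserts appear in this and the next two hunks: after the msgspec conversion, `best_of` is declared `Optional[int]` and `stop` as `Optional[Union[str, List[str]]]` at class level. `__post_init__` always replaces the `None` and bare-`str` cases, but static type checkers cannot see that, so `assert isinstance(...)` narrows the types before the comparisons and iteration that follow.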
@@ -257,6 +248,7 @@ class SamplingParams:
                 and self.truncate_prompt_tokens < 1):
             raise ValueError(f"truncate_prompt_tokens must be >= 1, "
                              f"got {self.truncate_prompt_tokens}")
+        assert isinstance(self.stop, list)
         if any(not stop_str for stop_str in self.stop):
             raise ValueError("stop cannot contain an empty string.")
         if self.stop and not self.detokenize:
@@ -290,6 +282,7 @@ class SamplingParams:
                 "default value of 1.0 when not using beam search.")
 
     def _verify_greedy_sampling(self) -> None:
+        assert isinstance(self.best_of, int)
        if self.best_of > 1:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")
@@ -303,7 +296,7 @@ class SamplingParams:
         if model_eos_token_id is not None:
             # Add the eos token id into the sampling_params to support
             # min_tokens processing.
-            self.all_stop_token_ids.add(model_eos_token_id)
+            self._all_stop_token_ids.add(model_eos_token_id)
 
         # Update eos_token_id for generation
         if (eos_ids := generation_config.get("eos_token_id")) is not None:
@@ -315,7 +308,7 @@ class SamplingParams:
             # purposes.
             eos_ids.discard(model_eos_token_id)
             if eos_ids:
-                self.all_stop_token_ids.update(eos_ids)
+                self._all_stop_token_ids.update(eos_ids)
                 if not self.ignore_eos:
                     eos_ids.update(self.stop_token_ids)
                     self.stop_token_ids = list(eos_ids)
@@ -330,6 +323,10 @@ class SamplingParams:
             return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM
 
+    @property
+    def all_stop_token_ids(self) -> Set[int]:
+        return self._all_stop_token_ids
+
     def clone(self) -> "SamplingParams":
         """Deep copy excluding LogitsProcessor objects.
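The rename to `_all_stop_token_ids` (this hunk plus the `@@ -303` and `@@ -315` hunks above) keeps the derived set as a private field that is not meant as user input, while the read-only property preserves the public attribute name for existing callers. A rough usage sketch, assuming default-constructed parameters pass `_verify_args`:

    params = SamplingParams(stop_token_ids=[7, 11])
    print(params.all_stop_token_ids)  # {7, 11}, built in __post_init__ from stop_token_ids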