[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)

This commit is contained in:
SangBin Cho
2024-08-18 17:57:20 -07:00
committed by GitHub
parent 200a2ffa6b
commit ff7ec82c4d
36 changed files with 722 additions and 346 deletions

View File

@@ -2,10 +2,10 @@
import copy
from enum import IntEnum
from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
import msgspec
import torch
from pydantic import Field
from typing_extensions import Annotated
from vllm.logger import init_logger
@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
to sample from."""
class SamplingParams:
class SamplingParams(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True): # type: ignore[call-arg]
"""Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion
@@ -112,87 +116,73 @@ class SamplingParams:
(i.e., no truncation).
"""
def __init__(
self,
n: int = 1,
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
if 0 < temperature < _MAX_TEMP:
n: int = 1
best_of: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
repetition_penalty: float = 1.0
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1
min_p: float = 0.0
seed: Optional[int] = None
use_beam_search: bool = False
length_penalty: float = 1.0
early_stopping: Union[bool, str] = False
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
ignore_eos: bool = False
max_tokens: Optional[int] = 16
min_tokens: int = 0
logprobs: Optional[int] = None
prompt_logprobs: Optional[int] = None
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
detokenize: bool = True
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
# Optional[List[LogitsProcessor]] type. We use Any here because
# Optional[List[LogitsProcessor]] type is not supported by msgspec.
logits_processors: Optional[Any] = None
include_stop_str_in_output: bool = False
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
def __post_init__(self) -> None:
self.best_of = self.best_of or self.n
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s.",
temperature, _MAX_TEMP, _MAX_TEMP)
temperature = max(temperature, _MAX_TEMP)
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
if seed == -1:
self.temperature, _MAX_TEMP, _MAX_TEMP)
self.temperature = max(self.temperature, _MAX_TEMP)
if self.seed == -1:
self.seed = None
else:
self.seed = seed
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.seed = self.seed
if self.stop is None:
self.stop = []
elif isinstance(stop, str):
self.stop = [stop]
elif isinstance(self.stop, str):
self.stop = [self.stop]
else:
self.stop = list(stop)
if stop_token_ids is None:
self.stop = list(self.stop)
if self.stop_token_ids is None:
self.stop_token_ids = []
else:
self.stop_token_ids = list(stop_token_ids)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.min_tokens = min_tokens
self.logprobs = 1 if logprobs is True else logprobs
self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self.detokenize = detokenize
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
self.stop_token_ids = list(self.stop_token_ids)
self.logprobs = 1 if self.logprobs is True else self.logprobs
self.prompt_logprobs = (1 if self.prompt_logprobs is True else
self.prompt_logprobs)
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if self.stop and not include_stop_str_in_output:
if self.stop and not self.include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.output_text_buffer_length = 0
self._verify_args()
if self.use_beam_search:
@@ -206,11 +196,12 @@ class SamplingParams:
self.min_p = 0.0
self._verify_greedy_sampling()
# eos_token_id is added to this by the engine
self.all_stop_token_ids = set(self.stop_token_ids)
self._all_stop_token_ids = set(self.stop_token_ids)
def _verify_args(self) -> None:
if self.n < 1:
raise ValueError(f"n must be at least 1, got {self.n}.")
assert isinstance(self.best_of, int)
if self.best_of < self.n:
raise ValueError(f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
@@ -257,6 +248,7 @@ class SamplingParams:
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
assert isinstance(self.stop, list)
if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
@@ -290,6 +282,7 @@ class SamplingParams:
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int)
if self.best_of > 1:
raise ValueError("best_of must be 1 when using greedy sampling."
f"Got {self.best_of}.")
@@ -303,7 +296,7 @@ class SamplingParams:
if model_eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self.all_stop_token_ids.add(model_eos_token_id)
self._all_stop_token_ids.add(model_eos_token_id)
# Update eos_token_id for generation
if (eos_ids := generation_config.get("eos_token_id")) is not None:
@@ -315,7 +308,7 @@ class SamplingParams:
# purposes.
eos_ids.discard(model_eos_token_id)
if eos_ids:
self.all_stop_token_ids.update(eos_ids)
self._all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)
@@ -330,6 +323,10 @@ class SamplingParams:
return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
@property
def all_stop_token_ids(self) -> Set[int]:
return self._all_stop_token_ids
def clone(self) -> "SamplingParams":
"""Deep copy excluding LogitsProcessor objects.