[Feature] Add load generation config from model (#11164)

Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Yanyi Liu
2024-12-19 18:50:38 +08:00
committed by GitHub
parent 98356735ac
commit 5aef49806d
10 changed files with 307 additions and 74 deletions

View File

@@ -27,7 +27,8 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)
@@ -160,6 +161,7 @@ class ModelConfig:
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
"""
def compute_hash(self) -> str:
@@ -218,7 +220,8 @@ class ModelConfig:
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[Dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None) -> None:
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -348,6 +351,8 @@ class ModelConfig:
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
@@ -813,6 +818,56 @@ class ModelConfig:
return self.multimodal_config
def try_get_generation_config(self) -> Dict[str, Any]:
    """Load the generation config and return it as a plain dict.

    When ``self.generation_config`` is ``None`` or ``"auto"`` the config is
    looked up from ``self.model`` at ``self.revision``. Any other value is
    passed to the loader in place of the model (no revision is forwarded
    in that case).

    Returns:
        The ``to_diff_dict()`` output of the loaded config, or an empty
        dict when the loader returns ``None``.
    """
    load_from_model = (self.generation_config is None
                       or self.generation_config == "auto")
    if load_from_model:
        loaded = try_get_generation_config(
            self.model,
            trust_remote_code=self.trust_remote_code,
            revision=self.revision,
        )
    else:
        # An explicit location was given; use it instead of the model.
        loaded = try_get_generation_config(
            self.generation_config,
            trust_remote_code=self.trust_remote_code,
        )

    return {} if loaded is None else loaded.to_diff_dict()
def get_diff_sampling_param(self) -> Dict[str, Any]:
    """Return sampling-parameter overrides taken from the generation config.

    Nothing is loaded when ``self.generation_config`` is ``None``; an empty
    dict is returned immediately. Otherwise the dict produced by
    ``self.try_get_generation_config()`` is filtered down to the supported
    sampling keys, and entries whose value is ``None`` are dropped.

    Returns:
        Dict[str, Any]: A mapping containing a subset of
        ``repetition_penalty``, ``temperature``, ``top_k``, ``top_p`` and
        ``min_p``; empty when ``generation_config`` is not set or none of
        these keys carries a non-``None`` value.
    """
    if self.generation_config is None:
        # When generation_config is not set
        return {}

    config = self.try_get_generation_config()
    available_params = (
        "repetition_penalty",
        "temperature",
        "top_k",
        "top_p",
        "min_p",
    )
    # A single pass is enough: keys that are absent (or explicitly None)
    # are filtered out, so no separate "any key present" check is needed.
    return {
        name: value
        for name in available_params
        if (value := config.get(name)) is not None
    }
@property
def is_encoder_decoder(self) -> bool:
"""Extract the HF encoder/decoder model flag."""

View File

@@ -197,6 +197,8 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = None
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
@@ -942,6 +944,16 @@ class EngineArgs:
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
"--generation-config",
type=nullable_str,
default=None,
help="The folder path to the generation config. "
"Defaults to None, will use the default generation config in vLLM. "
"If set to 'auto', the generation config will be automatically "
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path.")
return parser
@classmethod
@@ -985,7 +997,8 @@ class EngineArgs:
disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern)
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config)
def create_load_config(self) -> LoadConfig:
return LoadConfig(

View File

@@ -5,8 +5,8 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
@@ -52,7 +52,6 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
SequenceGroupOutput, SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
@@ -65,20 +64,6 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Fetch the generation config for ``model_config.model`` as a dict.

    The lookup uses the model's ``trust_remote_code`` and ``revision``
    settings. Returns an empty dict when no config is found, otherwise the
    config's ``to_diff_dict()`` output.
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    if generation_config is not None:
        return generation_config.to_diff_dict()
    return {}
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
@@ -274,8 +259,8 @@ class LLMEngine:
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
self.model_config)
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,

View File

@@ -258,6 +258,13 @@ class LLM:
else:
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
def get_default_sampling_params(self) -> SamplingParams:
    """Build the sampling params used when the caller supplies none.

    Overrides reported by the model config's ``get_diff_sampling_param()``
    are applied through ``SamplingParams.from_optional``; when there are
    no overrides a plain ``SamplingParams()`` is returned.
    """
    overrides = self.llm_engine.model_config.get_diff_sampling_param()
    if not overrides:
        return SamplingParams()
    return SamplingParams.from_optional(**overrides)
@overload
def generate(
self,
@@ -441,7 +448,7 @@ class LLM:
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
sampling_params = self.get_default_sampling_params()
self._validate_and_add_requests(
prompts=parsed_prompts,

View File

@@ -211,8 +211,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[List[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
@@ -224,9 +224,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@@ -348,15 +348,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
# Default sampling parameters for chat completion requests.
# Used as the last-resort fallback: a value is taken from here only when
# the request field is None and the `default_sampling_params` argument
# passed to the conversion methods has no entry for that key.
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
return BeamSearchParams(
beam_width=n,
@@ -367,13 +384,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
@@ -403,11 +443,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
@@ -584,15 +624,15 @@ class CompletionRequest(OpenAIBaseModel):
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
temperature: Optional[float] = None
top_p: Optional[float] = None
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
top_k: Optional[int] = None
min_p: Optional[float] = None
repetition_penalty: Optional[float] = None
length_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
@@ -669,14 +709,30 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
# Default sampling parameters for completion requests.
# `to_sampling_params` falls back to these values when the request field
# is None and `default_sampling_params` has no entry for that key.
_DEFAULT_SAMPLING_PARAMS: dict = {
"repetition_penalty": 1.0,
"temperature": 1.0,
"top_p": 1.0,
"top_k": -1,
"min_p": 0.0,
}
def to_beam_search_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None
) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get("temperature", 1.0)
return BeamSearchParams(
beam_width=n,
@@ -687,12 +743,35 @@ class CompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
)
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.logprobs
@@ -718,11 +797,11 @@ class CompletionRequest(OpenAIBaseModel):
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,

View File

@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
"been registered") from e
self.enable_prompt_tokens_details = enable_prompt_tokens_details
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info("Overwriting default chat sampling param with: %s",
diff_sampling_param)
async def create_chat_completion(
self,
@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern)
self.model_config.logits_processor_pattern,
default_sampling_params)
self._log_inputs(request_id,
request_prompts[i],

View File

@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_adapters=prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info(
"Overwriting default completion sampling param with: %s",
diff_sampling_param)
async def create_completion(
self,
@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
# Build default sampling params
default_sampling_params = (
self.model_config.get_diff_sampling_param())
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
default_max_tokens, default_sampling_params)
else:
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern)
self.model_config.logits_processor_pattern,
default_sampling_params)
request_id_item = f"{request_id}-{i}"

View File

@@ -1,5 +1,5 @@
import time
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from typing import Mapping, Optional, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@@ -12,7 +12,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@@ -34,8 +33,8 @@ class Processor:
self.lora_config = lora_config
self.tokenizer = tokenizer
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.generation_config_fields = model_config.try_get_generation_config(
)
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer,
mm_registry)
@@ -181,16 +180,3 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a dict, or ``{}`` if absent.

    The config is requested for ``model_config.model`` using its
    ``trust_remote_code`` and ``revision`` values; a found config is
    converted with ``to_diff_dict()``.
    """
    found = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    return {} if found is None else found.to_diff_dict()