[Feature] Add load generation config from model (#11164)
Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -197,6 +197,8 @@ class EngineArgs:
 
     kv_transfer_config: Optional[KVTransferConfig] = None
 
+    generation_config: Optional[str] = None
+
     def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model
@@ -942,6 +944,16 @@ class EngineArgs:
            default="auto",
            help='The worker class to use for distributed execution.')
 
+        parser.add_argument(
+            "--generation-config",
+            type=nullable_str,
+            default=None,
+            help="The folder path to the generation config. "
+            "Defaults to None, will use the default generation config in vLLM. "
+            "If set to 'auto', the generation config will be automatically "
+            "loaded from model. If set to a folder path, the generation config "
+            "will be loaded from the specified folder path.")
+
        return parser
 
    @classmethod
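
For reference, the new flag mirrors the `generation_config` dataclass field added in the first hunk. Below is a minimal usage sketch that is not part of the commit: it assumes the `LLM` entrypoint forwards extra keyword arguments to `EngineArgs`, and the model name is only a placeholder.

    from vllm import LLM, SamplingParams

    # Sketch only: "auto" asks vLLM to load the generation config shipped with
    # the model; a folder path loads the config from that folder; None (the
    # default) keeps vLLM's built-in generation defaults.
    llm = LLM(model="facebook/opt-125m", generation_config="auto")

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)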
@@ -985,7 +997,8 @@ class EngineArgs:
            disable_mm_preprocessor_cache=self.disable_mm_preprocessor_cache,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
-            logits_processor_pattern=self.logits_processor_pattern)
+            logits_processor_pattern=self.logits_processor_pattern,
+            generation_config=self.generation_config)
 
    def create_load_config(self) -> LoadConfig:
        return LoadConfig(
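
To make the plumbing above concrete, here is a small illustrative snippet, not from the commit: `create_model_config` is assumed to be the `EngineArgs` method containing the hunk above (it sits next to `create_load_config`), and the attribute check is only presumed behaviour.

    from vllm.engine.arg_utils import EngineArgs

    # Assumed method name and attribute: the hunk above passes
    # generation_config into the ModelConfig constructor, so the value is
    # expected to surface on the resulting config object.
    engine_args = EngineArgs(model="facebook/opt-125m",
                             generation_config="auto")
    model_config = engine_args.create_model_config()
    assert model_config.generation_config == "auto"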
@@ -5,8 +5,8 @@ from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
-                    Iterable, List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
+                    List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
@@ -52,7 +52,6 @@ from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                            SequenceGroupOutput, SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
-from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import (
@@ -65,20 +64,6 @@ from vllm.version import __version__ as VLLM_VERSION
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 
-def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
-    config = try_get_generation_config(
-        model_config.model,
-        trust_remote_code=model_config.trust_remote_code,
-        revision=model_config.revision,
-    )
-
-    if config is None:
-        return {}
-
-    return config.to_diff_dict()
-
-
 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
@@ -274,8 +259,8 @@ class LLMEngine:
             return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
 
         self.seq_counter = Counter()
-        self.generation_config_fields = _load_generation_config_dict(
-            self.model_config)
+        self.generation_config_fields = (
+            self.model_config.try_get_generation_config())
 
         self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer,
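
The module-level `_load_generation_config_dict` helper removed above is superseded by a `try_get_generation_config()` method on `ModelConfig`, whose body is not shown in this diff. A rough sketch of what such a method could look like, assuming it reuses the removed helper's logic and additionally honours the new `generation_config` setting ('auto' or a folder path); the class stub and the None/'auto' handling are assumptions, not the committed implementation.

    # Hypothetical reconstruction, not the committed implementation.
    from typing import Any, Dict, Optional

    from vllm.transformers_utils.config import try_get_generation_config


    class ModelConfig:
        # Illustrative stub: only the attributes the method below relies on.
        model: str
        revision: Optional[str]
        trust_remote_code: bool
        generation_config: Optional[str]

        def try_get_generation_config(self) -> Dict[str, Any]:
            # None/"auto" keeps the old engine behaviour of reading the config
            # that ships with the model; a folder path overrides the location.
            if self.generation_config in (None, "auto"):
                config = try_get_generation_config(
                    self.model,
                    trust_remote_code=self.trust_remote_code,
                    revision=self.revision,
                )
            else:
                config = try_get_generation_config(
                    self.generation_config,
                    trust_remote_code=self.trust_remote_code,
                )
            if config is None:
                return {}
            # As in the removed helper: only fields that differ from defaults.
            return config.to_diff_dict()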