[Bugfix] Support eos_token_id from config.json (#5954)
This commit is contained in:
@@ -1,10 +1,10 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, TypeVar, Union
|
||||
|
||||
from transformers import GenerationConfig, PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
LoRAConfig, ModelConfig, ObservabilityConfig,
|
||||
@@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
SequenceStatus)
|
||||
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
|
||||
init_tracer)
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
||||
get_tokenizer_group)
|
||||
@@ -46,16 +47,18 @@ logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
def _load_generation_config_dict(model_config: ModelConfig):
|
||||
try:
|
||||
return GenerationConfig.from_pretrained(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
).to_diff_dict()
|
||||
except OSError:
|
||||
# Not found.
|
||||
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
|
||||
config = try_get_generation_config(
|
||||
model_config.model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
revision=model_config.revision,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
return {}
|
||||
|
||||
return config.to_diff_dict()
|
||||
|
||||
|
||||
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user