[Bugfix] Support eos_token_id from config.json (#5954)

This commit is contained in:
Cyrus Leung
2024-06-29 19:19:02 +08:00
committed by GitHub
parent 329df38f1a
commit 51e971d39e
3 changed files with 67 additions and 11 deletions

View File

@@ -1,10 +1,10 @@
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer
from transformers import PreTrainedTokenizer
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ObservabilityConfig,
@@ -34,6 +34,7 @@ from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
@@ -46,16 +47,18 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig):
try:
return GenerationConfig.from_pretrained(
model_config.model,
revision=model_config.revision,
).to_diff_dict()
except OSError:
# Not found.
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
config = try_get_generation_config(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.revision,
)
if config is None:
return {}
return config.to_diff_dict()
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)