Support eos_token_id from generation_config.json (#4182)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
from typing import Iterable, List, Optional, Type, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import GenerationConfig, PreTrainedTokenizer
|
||||
|
||||
import vllm
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
def _load_generation_config_dict(model_config: ModelConfig):
|
||||
try:
|
||||
return GenerationConfig.from_pretrained(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
).to_diff_dict()
|
||||
except OSError:
|
||||
# Not found.
|
||||
return {}
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""An LLM engine that receives requests and generates texts.
|
||||
|
||||
@@ -124,6 +135,8 @@ class LLMEngine:
|
||||
self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
self.seq_counter = Counter()
|
||||
self.generation_config_fields = _load_generation_config_dict(
|
||||
model_config)
|
||||
|
||||
self.model_executor = executor_class(
|
||||
model_config=model_config,
|
||||
@@ -391,6 +404,8 @@ class LLMEngine:
|
||||
# inject the eos token id into the sampling_params to support min_tokens
|
||||
# processing
|
||||
sampling_params.eos_token_id = seq.eos_token_id
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields)
|
||||
|
||||
# Create the sequence group.
|
||||
seq_group = SequenceGroup(request_id, [seq], sampling_params,
|
||||
@@ -435,7 +450,7 @@ class LLMEngine:
|
||||
scheduled_seq_groups: List[SequenceGroup],
|
||||
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
|
||||
"""Apply the model output to the sequences in the scheduled seq groups.
|
||||
|
||||
|
||||
Returns RequestOutputs that can be returned to the client.
|
||||
"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user