Support eos_token_id from generation_config.json (#4182)

This commit is contained in:
Simon Mo
2024-04-18 21:13:36 -07:00
committed by GitHub
parent 8a7a3e4436
commit a134ef6f5e
2 changed files with 30 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
import time
from typing import Iterable, List, Optional, Type, Union
from transformers import PreTrainedTokenizer
from transformers import GenerationConfig, PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
@@ -34,6 +34,17 @@ logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
def _load_generation_config_dict(model_config: ModelConfig):
try:
return GenerationConfig.from_pretrained(
model_config.model,
revision=model_config.revision,
).to_diff_dict()
except OSError:
# Not found.
return {}
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@@ -124,6 +135,8 @@ class LLMEngine:
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
model_config)
self.model_executor = executor_class(
model_config=model_config,
@@ -391,6 +404,8 @@ class LLMEngine:
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id
sampling_params.update_from_generation_config(
self.generation_config_fields)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -435,7 +450,7 @@ class LLMEngine:
scheduled_seq_groups: List[SequenceGroup],
ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
"""Apply the model output to the sequences in the scheduled seq groups.
Returns RequestOutputs that can be returned to the client.
"""