Make initialization of tokenizer and detokenizer optional (#3748)

Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author:    GeauxEric
Date:      2024-04-21 15:06:46 -07:00
Committer: GitHub
Parent:    7f2593b164
Commit:    a37d815b83
6 changed files with 68 additions and 12 deletions
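
The new skip_tokenizer_init flag targets callers that already hold token IDs and do not need text output. A minimal usage sketch against the Python entrypoint of this era, assuming the flag is forwarded from LLM/EngineArgs into ModelConfig as this commit intends (model name and token IDs are placeholders):

    from vllm import LLM, SamplingParams

    # Skip loading the tokenizer/detokenizer entirely; with no tokenizer,
    # prompts must already be provided as token IDs.
    llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

    # Disable detokenization, since no detokenizer exists to turn the
    # generated token IDs back into text.
    params = SamplingParams(max_tokens=8, temperature=0.0, detokenize=False)

    outputs = llm.generate(prompt_token_ids=[[1, 15, 27, 9]],
                           sampling_params=params)
    for out in outputs:
        # Only token IDs are available; the text field stays empty.
        print(out.outputs[0].token_ids)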


@@ -100,6 +100,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"speculative_config={speculative_config!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
+            f"skip_tokenizer_init={model_config.skip_tokenizer_init}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
             f"revision={model_config.revision}, "
             f"tokenizer_revision={model_config.tokenizer_revision}, "
@@ -132,8 +133,14 @@ class LLMEngine:
         self.decoding_config = decoding_config or DecodingConfig()
         self.log_stats = log_stats
 
-        self._init_tokenizer()
-        self.detokenizer = Detokenizer(self.tokenizer)
+        if not self.model_config.skip_tokenizer_init:
+            self.tokenizer: BaseTokenizerGroup
+            self._init_tokenizer()
+            self.detokenizer = Detokenizer(self.tokenizer)
+        else:
+            self.detokenizer = None
+            self.tokenizer = None
+
         self.seq_counter = Counter()
         self.generation_config_fields = _load_generation_config_dict(
             model_config)
@@ -187,9 +194,10 @@ class LLMEngine:
                 parallel_config.disable_custom_all_reduce,
             })
 
-        # Ping the tokenizer to ensure liveness if it runs in a
-        # different process.
-        self.tokenizer.ping()
+        if self.tokenizer:
+            # Ping the tokenizer to ensure liveness if it runs in a
+            # different process.
+            self.tokenizer.ping()
 
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
@@ -296,7 +304,7 @@ class LLMEngine:
             trust_remote_code=self.model_config.trust_remote_code,
             revision=self.model_config.tokenizer_revision)
         init_kwargs.update(tokenizer_init_kwargs)
-        self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
+        self.tokenizer = get_tokenizer_group(
             self.parallel_config.tokenizer_pool_config, **init_kwargs)
 
     def _verify_args(self) -> None:
@@ -393,8 +401,13 @@ class LLMEngine:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
-        eos_token_id = self.tokenizer.get_lora_tokenizer(
-            lora_request).eos_token_id
+        eos_token_id = None
+        if self.tokenizer:
+            eos_token_id = self.tokenizer.get_lora_tokenizer(
+                lora_request).eos_token_id
+        else:
+            logger.warning("Use None for EOS token id because tokenizer is "
+                           "not initialized")
         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                        eos_token_id, lora_request)
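
Because eos_token_id now falls back to None when the tokenizer is skipped, sequences will not stop on the model's EOS token by themselves. A hedged sketch of how a caller might bound generation in that mode (the stop token ID below is a placeholder for whatever model is actually in use):

    from vllm import SamplingParams

    params = SamplingParams(
        max_tokens=32,        # hard cap, since EOS-based stopping is unavailable
        stop_token_ids=[2],   # optionally pass the model's EOS id explicitly
        detokenize=False,     # no detokenizer is initialized in this mode
    )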