Raise error for long prompt (#273)

Author: Lily Liu
Date: 2023-06-30 18:48:49 -07:00
Committed by: GitHub
Parent: 598dc4b79a
Commit: dafd924c1f

5 changed files with 42 additions and 11 deletions

@@ -123,8 +123,12 @@ class EngineArgs:
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray)
+        max_seq_len = min(
+            self.max_num_batched_tokens,
+            getattr(model_config.hf_config, "max_position_embeddings",
+                    float("inf")))
         scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs)
+                                           self.max_num_seqs, max_seq_len)
         return model_config, cache_config, parallel_config, scheduler_config
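
The net effect of this hunk: the engine derives a hard sequence-length cap from whichever is smaller, the token-batching limit or the model's positional-embedding limit, and threads it into the scheduler config. A minimal standalone sketch of that logic, assuming a SchedulerConfig that simply stores the third argument (the real class lives elsewhere in the repo and may carry more state):

from dataclasses import dataclass

@dataclass
class SchedulerConfig:
    # Sketch only: mirrors the three arguments passed at the call site above;
    # the actual class is defined in vllm's config module.
    max_num_batched_tokens: int
    max_num_seqs: int
    max_seq_len: int

def compute_max_seq_len(max_num_batched_tokens, hf_config):
    # Cap at the model's positional limit when the HF config declares one
    # (not every config does, hence the getattr fallback to infinity).
    return min(max_num_batched_tokens,
               getattr(hf_config, "max_position_embeddings", float("inf")))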

@@ -226,8 +226,8 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
+        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
+        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
             # Nothing to do.
             return []
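
The schedule() change that produces ignored_seq_groups lives in the scheduler file, one of the five changed files but not excerpted here. A hedged sketch of the likely shape, with assumed identifiers (self.waiting, get_seqs(), FINISHED_IGNORED): prompts already longer than max_seq_len are marked finished up front and handed back, so the engine reports them instead of scheduling work that could never fit.

# Sketch of the assumed scheduler-side filtering; identifiers here are
# illustrative, not necessarily the repo's exact implementation.
ignored_seq_groups = []
for seq_group in list(self.waiting):
    num_prompt_tokens = seq_group.get_seqs()[0].get_len()
    if num_prompt_tokens > self.scheduler_config.max_seq_len:
        # Too long to ever fit: mark every sequence finished and surface
        # the group to the caller instead of queueing it.
        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.FINISHED_IGNORED
        ignored_seq_groups.append(seq_group)
        self.waiting.remove(seq_group)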
@@ -251,7 +251,7 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups:
+        for seq_group in seq_groups + ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs
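
Folding ignored groups into the same output list keeps the caller contract unchanged: an over-long prompt comes back as an already-finished RequestOutput rather than a request that never completes. A hypothetical caller loop, assuming the engine's has_unfinished_requests() helper (handle is a placeholder):

# Drain the engine as usual; over-long prompts terminate through the same
# path as normal completions.
while engine.has_unfinished_requests():
    for request_output in engine.step():
        handle(request_output)  # placeholder for caller-side processing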
@@ -288,6 +288,12 @@ class LLMEngine:
             if stopped:
                 continue
+            # Check if the sequence has reached max_seq_len.
+            if (seq.get_len() >=
+                    self.scheduler.scheduler_config.max_seq_len):
+                self.scheduler.free_seq(
+                    seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
+                continue
             # Check if the sequence has reached max_tokens.
             if seq.get_output_len() == sampling_params.max_tokens:
                 self.scheduler.free_seq(
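
SequenceStatus.FINISHED_LENGTH_CAPPED is introduced by this commit in the sequence module (another of the changed files, not excerpted). A sketch of the enum as this hunk uses it; only FINISHED_LENGTH_CAPPED is confirmed by the diff, the other members are assumptions:

import enum

class SequenceStatus(enum.Enum):
    # Sketch: member set beyond FINISHED_LENGTH_CAPPED is assumed.
    WAITING = enum.auto()
    RUNNING = enum.auto()
    FINISHED_STOPPED = enum.auto()
    FINISHED_LENGTH_CAPPED = enum.auto()  # hit max_seq_len while decoding
    FINISHED_IGNORED = enum.auto()        # prompt too long to schedule at all

free_seq presumably releases the sequence's cache blocks and records the finish reason, which is why the check continues to the next sequence rather than raising.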