[Quality] Add code formatter and linter (#326)

Authored by Zhuohan Li on 2023-07-03 11:31:55 -07:00; committed by GitHub
parent 0ffded812a
commit d6fa1be3a8
47 changed files with 1547 additions and 617 deletions
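
The hunks below are mostly mechanical re-wrapping applied by the newly added formatter rather than behavioral changes. As a rough illustration of what such a formatter does to an over-long call, here is a minimal sketch assuming the formatter is yapf with a pep8-derived style; the project's actual tool configuration is not part of this excerpt, so both the tool and the style_config value are assumptions:

# Minimal sketch, assuming yapf; the style_config value is an assumption.
# FormatCode returns the reformatted source plus a flag saying whether it changed.
from yapf.yapflib.yapf_api import FormatCode

source = ("self.tokenizer = get_tokenizer(model_config.tokenizer, "
          "tokenizer_mode=model_config.tokenizer_mode)\n")
formatted, changed = FormatCode(source, style_config="pep8")
print(changed)
print(formatted)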


@@ -67,8 +67,7 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})"
)
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
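
The wrapped f-string fragments above (part of a single logging call that is not fully shown in the hunk) rely on Python's implicit concatenation of adjacent string literals inside parentheses. A self-contained illustration with invented values:

# Adjacent (f-)string literals inside one set of parentheses join into one message.
model = "facebook/opt-125m"  # invented value, for illustration only
seed = 0
message = (f"model={model!r}, "
           f"seed={seed}")
assert message == "model='facebook/opt-125m', seed=0"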
@@ -78,8 +77,8 @@ class LLMEngine:
         self.log_stats = log_stats
         self._verify_args()
-        self.tokenizer = get_tokenizer(model_config.tokenizer,
-                                       model_config.tokenizer_mode)
+        self.tokenizer = get_tokenizer(
+            model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)
         self.seq_counter = Counter()
         # Create the parallel GPU workers.
@@ -129,8 +128,8 @@ class LLMEngine:
         num_gpu_blocks = min(b[0] for b in num_blocks)
         num_cpu_blocks = min(b[1] for b in num_blocks)
         # FIXME(woosuk): Change to debug log.
-        logger.info(f'# GPU blocks: {num_gpu_blocks}, '
-                    f'# CPU blocks: {num_cpu_blocks}')
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
         if num_gpu_blocks <= 0:
             raise ValueError("No available memory for the cache blocks. "
@@ -152,7 +151,9 @@ class LLMEngine:
         # Initialize the cluster.
         distributed_init_method, devices = initialize_cluster(parallel_config)
         # Create the LLM engine.
-        engine = cls(*engine_configs, distributed_init_method, devices,
+        engine = cls(*engine_configs,
+                     distributed_init_method,
+                     devices,
                      log_stats=not engine_args.disable_log_stats)
         return engine
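
For context, a factory classmethod like the one reformatted above is typically driven from EngineArgs. A hedged usage sketch, assuming EngineArgs and LLMEngine are exported from the top-level vllm package (the model name is invented and not taken from this diff):

# Hypothetical construction via the factory shown above.
from vllm import EngineArgs, LLMEngine

engine_args = EngineArgs(model="facebook/opt-125m", disable_log_stats=True)
engine = LLMEngine.from_engine_args(engine_args)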
@@ -226,8 +227,10 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
-        if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
+        (seq_group_metadata_list, scheduler_outputs,
+         ignored_seq_groups) = self.scheduler.schedule()
+        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
+                and (not ignored_seq_groups)):
             # Nothing to do.
             return []
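
The docstring above describes one schedule-execute-update-decode iteration; callers normally drive step() in a loop. A minimal sketch continuing from the construction example earlier (request id, prompt, and sampling values are invented):

# Hypothetical driver loop; `engine` is assumed to come from
# LLMEngine.from_engine_args as sketched above.
from vllm import SamplingParams

engine.add_request(request_id="0",
                   prompt="The capital of France is",
                   sampling_params=SamplingParams(temperature=0.0))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)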
@@ -281,8 +284,8 @@ class LLMEngine:
                     # Truncate the output text so that the stop string is
                     # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     stopped = True
                     break
             if stopped:
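
The truncation above drops the matched stop string from the tail of the generated text. A tiny standalone illustration of the same idiom with invented values:

# Removing a matched stop string from the end of the output.
output_text = "The answer is 42.\nUSER:"
stop_str = "\nUSER:"
if output_text.endswith(stop_str):
    output_text = output_text[:-len(stop_str)]
assert output_text == "The answer is 42."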
@@ -290,7 +293,7 @@ class LLMEngine:
             # Check if the sequence has reached max_seq_len.
             if (seq.get_len() >=
-                self.scheduler.scheduler_config.max_seq_len):
+                    self.scheduler.scheduler_config.max_seq_len):
                 self.scheduler.free_seq(
                     seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                 continue
@@ -302,15 +305,15 @@ class LLMEngine:
             # Check if the sequence has generated the EOS token.
             if not sampling_params.ignore_eos:
                 if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                    self.scheduler.free_seq(seq,
-                                            SequenceStatus.FINISHED_STOPPED)
+                    self.scheduler.free_seq(
+                        seq, SequenceStatus.FINISHED_STOPPED)
                     continue

     def _run_workers(
         self,
         method: str,
-        get_all_outputs: bool = False,
         *args,
+        get_all_outputs: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""