Fix various issues of async servers (#135)

This commit is contained in:
Zhuohan Li
2023-06-05 23:44:50 +08:00
committed by GitHub
parent 8274ca23ac
commit 1a956e136b
11 changed files with 289 additions and 121 deletions

View File

@@ -1,11 +1,6 @@
import time
from typing import Any, List, Optional
try:
import ray
except ImportError:
ray = None
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.ray_utils import initialize_cluster
from cacheflow.server.ray_utils import ray, initialize_cluster
from cacheflow.server.tokenizer_utils import (get_tokenizer,
detokenize_incrementally)
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -62,7 +57,7 @@ class LLMServer:
assert len(stage_devices) == 1, "Only support one stage for now."
for rank, node_resource, _ in stage_devices[0]:
worker_cls = Worker
if self.parallel_config.use_ray:
if self.parallel_config.worker_use_ray:
worker_cls = ray.remote(
num_cpus=0,
num_gpus=1,
@@ -152,6 +147,9 @@ class LLMServer:
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: str) -> None:
    """Forward an abort for ``request_id`` to the scheduler.

    Args:
        request_id: Identifier handed to ``scheduler.abort_seq_group``.
    """
    scheduler = self.scheduler
    scheduler.abort_seq_group(request_id)
def get_num_unfinished_requests(self) -> int:
    """Return the scheduler's count of unfinished sequence groups.

    Returns:
        Whatever ``scheduler.get_num_unfinished_seq_groups()`` reports.
    """
    unfinished_count = self.scheduler.get_num_unfinished_seq_groups()
    return unfinished_count
@@ -243,13 +241,13 @@ class LLMServer:
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
if self.parallel_config.use_ray:
if self.parallel_config.worker_use_ray:
executor = executor.remote
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.use_ray:
if self.parallel_config.worker_use_ray:
all_outputs = ray.get(all_outputs)
if get_all_outputs: