Fix various issues of async servers (#135)
This commit is contained in:
@@ -1,11 +1,6 @@
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
|
||||
try:
|
||||
import ray
|
||||
except ImportError:
|
||||
ray = None
|
||||
|
||||
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from cacheflow.core.scheduler import Scheduler
|
||||
@@ -13,7 +8,7 @@ from cacheflow.logger import init_logger
|
||||
from cacheflow.outputs import RequestOutput
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.server.arg_utils import ServerArgs
|
||||
from cacheflow.server.ray_utils import initialize_cluster
|
||||
from cacheflow.server.ray_utils import ray, initialize_cluster
|
||||
from cacheflow.server.tokenizer_utils import (get_tokenizer,
|
||||
detokenize_incrementally)
|
||||
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
@@ -62,7 +57,7 @@ class LLMServer:
|
||||
assert len(stage_devices) == 1, "Only support one stage for now."
|
||||
for rank, node_resource, _ in stage_devices[0]:
|
||||
worker_cls = Worker
|
||||
if self.parallel_config.use_ray:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
worker_cls = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=1,
|
||||
@@ -152,6 +147,9 @@ class LLMServer:
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
|
||||
def abort_request(self, request_id: str) -> None:
|
||||
self.scheduler.abort_seq_group(request_id)
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
return self.scheduler.get_num_unfinished_seq_groups()
|
||||
|
||||
@@ -243,13 +241,13 @@ class LLMServer:
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
executor = getattr(worker, method)
|
||||
if self.parallel_config.use_ray:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
executor = executor.remote
|
||||
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
|
||||
if self.parallel_config.use_ray:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
all_outputs = ray.get(all_outputs)
|
||||
|
||||
if get_all_outputs:
|
||||
|
||||
Reference in New Issue
Block a user