Introduce LLM class for offline inference (#115)

Woosuk Kwon
2023-05-21 17:04:18 -07:00
committed by GitHub
parent f746ced08d
commit 655a5e48df
9 changed files with 222 additions and 81 deletions
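
The headline change in this commit, not visible in the hunks below, is a new `LLM` class that wraps `LLMServer` for offline, batched inference. A minimal usage sketch, assuming the entrypoint is exported from the top-level package and exposes a `generate()` method (both assumptions, since that file is not part of this excerpt):

```python
# Sketch only: the import path and generate() signature are assumed from the
# commit description; the LLM entrypoint itself is not shown in this diff.
from cacheflow import LLM, SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m")  # builds an LLMServer internally
outputs = llm.generate(prompts, sampling_params)
for output in outputs:  # RequestOutput objects; exact fields not shown here
    print(output)
```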


@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ class LLMServer:
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ class LLMServer:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )
         # Since we use a shared centralized controller, we take the minimum
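
The rename from `swap_space` to `swap_space_bytes` suggests that `CacheConfig` now stores the CPU swap space in bytes rather than in the GiB value supplied on the command line. A rough sketch of what that conversion likely looks like (the `CacheConfig` definition is not in this diff, so everything except the `swap_space_bytes` name is an assumption):

```python
_GiB = 1 << 30

class CacheConfig:
    # Sketch only: convert the user-facing GiB value to bytes once, so that
    # callers can read cache_config.swap_space_bytes directly.
    def __init__(self, block_size: int, gpu_memory_utilization: float,
                 swap_space: int) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GiB  # GiB -> bytes
```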
@@ -107,6 +109,18 @@ class LLMServer:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)
 
+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
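
With `from_server_args`, callers can build an `LLMServer` from a single `ServerArgs` object instead of wiring up the configs, cluster, and devices by hand; passing `log_stats=not server_args.disable_log_stats` explicitly is also what lets the default be dropped from `__init__` in the earlier hunk. A minimal sketch of the new construction path (the module path and the `ServerArgs` constructor fields are assumptions, since `arg_utils.py` is not shown here):

```python
# Sketch only: module path and ServerArgs fields are assumed.
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.llm_server import LLMServer

server_args = ServerArgs(model="facebook/opt-125m")
server = LLMServer.from_server_args(server_args)
```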