Rename servers to engines (#152)

Zhuohan Li
2023-06-17 17:25:21 +08:00
committed by GitHub
parent bab8f3dd0d
commit e5464ee484
15 changed files with 165 additions and 174 deletions


@@ -1,12 +1,12 @@
 from typing import List, Optional, Union

 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter


@@ -21,7 +21,7 @@ class LLM:
     NOTE: This class is intended to be used for offline inference. For online
         serving, use the `AsyncLLMEngine` class instead.

-    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

     Args:
         model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        server_args = ServerArgs(
+        engine_args = EngineArgs(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMEngine.from_server_args(server_args)
+        self.llm_engine = LLMEngine.from_engine_args(engine_args)
         self.request_counter = Counter()

     def get_tokenizer(
         self,
     ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_server.tokenizer
+        return self.llm_engine.tokenizer

     def generate(
         self,
@@ -99,7 +99,7 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()

-        # Add requests to the server.
+        # Add requests to the engine.
         if prompts is not None:
             num_requests = len(prompts)
         else:
@@ -111,7 +111,7 @@ class LLM:
             else:
                 token_ids = prompt_token_ids[i]
             self._add_request(prompt, sampling_params, token_ids)
-        return self._run_server(use_tqdm)
+        return self._run_engine(use_tqdm)

     def _add_request(
         self,
@@ -120,18 +120,18 @@ class LLM:
         prompt_token_ids: Optional[List[int]],
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_server.add_request(request_id, prompt, sampling_params,
+        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                     prompt_token_ids)

-    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
         if use_tqdm:
-            num_requests = self.llm_server.get_num_unfinished_requests()
+            num_requests = self.llm_engine.get_num_unfinished_requests()
             pbar = tqdm(total=num_requests, desc="Processed prompts")
-        # Run the server.
+        # Run the engine.
         outputs: List[RequestOutput] = []
-        while self.llm_server.has_unfinished_requests():
-            step_outputs = self.llm_server.step()
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished():
                     outputs.append(output)
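
For reference, the renamed offline API ends up being used roughly as in the sketch below. This is a minimal, hypothetical example rather than code from this commit: the model name is illustrative, and it assumes `LLM` and `SamplingParams` remain re-exported from the package root.

# Hypothetical usage sketch for the renamed offline API; not part of this commit.
from cacheflow import LLM, SamplingParams  # assumes package-root re-exports

# Constructor keywords mirror the ones forwarded to EngineArgs in __init__.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, seed=0)

# generate() enqueues each prompt on the wrapped LLMEngine via _add_request(),
# then _run_engine() drives llm_engine.step() until no unfinished requests
# remain and returns the finished RequestOutput objects.
outputs = llm.generate(
    prompts=["Hello, my name is"],
    sampling_params=SamplingParams(temperature=0.8),
)
for output in outputs:
    print(output)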