Rename servers to engines (#152)
@@ -1,12 +1,12 @@
 from typing import List, Optional, Union

 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
 from cacheflow.utils import Counter


@@ -21,7 +21,7 @@ class LLM:

     NOTE: This class is intended to be used for offline inference. For online
     serving, use the `AsyncLLMEngine` class instead.
-    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

     Args:
         model: The name or path of a HuggingFace Transformers model.
@@ -45,20 +45,20 @@ class LLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        server_args = ServerArgs(
+        engine_args = EngineArgs(
             model=model,
             tensor_parallel_size=tensor_parallel_size,
             dtype=dtype,
             seed=seed,
             **kwargs,
         )
-        self.llm_server = LLMEngine.from_server_args(server_args)
+        self.llm_engine = LLMEngine.from_engine_args(engine_args)
         self.request_counter = Counter()

     def get_tokenizer(
         self,
     ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-        return self.llm_server.tokenizer
+        return self.llm_engine.tokenizer

     def generate(
         self,
@@ -99,7 +99,7 @@ class LLM:
             # Use default sampling params.
             sampling_params = SamplingParams()

-        # Add requests to the server.
+        # Add requests to the engine.
         if prompts is not None:
             num_requests = len(prompts)
         else:
@@ -111,7 +111,7 @@ class LLM:
             else:
                 token_ids = prompt_token_ids[i]
             self._add_request(prompt, sampling_params, token_ids)
-        return self._run_server(use_tqdm)
+        return self._run_engine(use_tqdm)

     def _add_request(
         self,
@@ -120,18 +120,18 @@ class LLM:
         prompt_token_ids: Optional[List[int]],
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_server.add_request(request_id, prompt, sampling_params,
+        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                     prompt_token_ids)

-    def _run_server(self, use_tqdm: bool) -> List[RequestOutput]:
+    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.
         if use_tqdm:
-            num_requests = self.llm_server.get_num_unfinished_requests()
+            num_requests = self.llm_engine.get_num_unfinished_requests()
             pbar = tqdm(total=num_requests, desc="Processed prompts")
-        # Run the server.
+        # Run the engine.
         outputs: List[RequestOutput] = []
-        while self.llm_server.has_unfinished_requests():
-            step_outputs = self.llm_server.step()
+        while self.llm_engine.has_unfinished_requests():
+            step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished():
                     outputs.append(output)
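After the rename, offline inference reads as follows. This is a minimal usage sketch, not part of the commit: it assumes the top-level cacheflow package re-exports LLM and SamplingParams, and "facebook/opt-125m" is a placeholder model name; the LLM(...), SamplingParams(), and generate(...) call shapes come from the diff above.

# Usage sketch of the renamed API. Assumptions (not in this diff): the
# top-level `cacheflow` package re-exports LLM and SamplingParams, and
# "facebook/opt-125m" stands in for any HuggingFace model name.
from cacheflow import LLM, SamplingParams

# LLM.__init__ builds EngineArgs from its keyword arguments and then calls
# LLMEngine.from_engine_args(engine_args), per the diff above.
llm = LLM(model="facebook/opt-125m", seed=0)

# generate() falls back to SamplingParams() when none is given; passing one
# explicitly mirrors that default. Each prompt becomes one engine request,
# and finished requests come back as RequestOutput objects.
outputs = llm.generate(["Hello, my name is"], SamplingParams())
for output in outputs:
    print(output)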
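The loop that _run_engine drives can also be written by hand against LLMEngine, e.g. to customize what happens around each step(). Another sketch, using only names introduced in this diff; it assumes EngineArgs' remaining fields have workable defaults and again uses a placeholder model name:

# Driving the engine by hand, mirroring _run_engine() above. The imports are
# exactly the ones this diff introduces; the model name is a placeholder,
# and we assume EngineArgs' other fields default sensibly.
from typing import List

from cacheflow.engine.arg_utils import EngineArgs
from cacheflow.engine.llm_engine import LLMEngine
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# add_request takes (request_id, prompt, sampling_params, prompt_token_ids),
# matching the call in _add_request above; token IDs are None when a text
# prompt is supplied.
engine.add_request("0", "Hello, my name is", SamplingParams(), None)

outputs: List[RequestOutput] = []
while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished():
            outputs.append(output)
print(outputs)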