Introduce LLM class for offline inference (#115)
This commit is contained in:
62
cacheflow/entrypoints/llm.py
Normal file
62
cacheflow/entrypoints/llm.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from cacheflow.outputs import RequestOutput
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.server.arg_utils import ServerArgs
|
||||
from cacheflow.server.llm_server import LLMServer
|
||||
from cacheflow.utils import Counter
|
||||
|
||||
|
||||
class LLM:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: str = "default",
|
||||
seed: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
kwargs["disable_log_stats"] = True
|
||||
server_args = ServerArgs(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
**kwargs,
|
||||
)
|
||||
self.llm_server = LLMServer.from_server_args(server_args)
|
||||
self.request_counter = Counter()
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
use_tqdm: bool = True,
|
||||
) -> List[RequestOutput]:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
# Initialize tqdm.
|
||||
if use_tqdm:
|
||||
pbar = tqdm(total=len(prompts), desc="Processed prompts")
|
||||
|
||||
# Add requests to the server.
|
||||
for prompt in prompts:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_server.add_request(request_id, prompt, sampling_params)
|
||||
|
||||
# Run the server.
|
||||
outputs: List[RequestOutput] = []
|
||||
while self.llm_server.has_unfinished_requests():
|
||||
step_outputs = self.llm_server.step()
|
||||
for output in step_outputs:
|
||||
if output.done:
|
||||
outputs.append(output)
|
||||
if use_tqdm:
|
||||
pbar.update(1)
|
||||
if use_tqdm:
|
||||
pbar.close()
|
||||
return outputs
|
||||
Reference in New Issue
Block a user