Compare commits


4 Commits

Author       SHA1        Message                                                                           Date
Zhuohan Li   83658c8ace  Bump up version to 0.1.1 (#204)                                                   2023-06-22 15:33:32 +08:00
Zhuohan Li   1d24ccb96c  [Fix] Better error message when there is OOM during cache initialization (#203)  2023-06-22 15:30:06 +08:00
Woosuk Kwon  14f0b39cda  [Bugfix] Fix a bug in RequestOutput.finished (#202)                               2023-06-22 00:17:24 -07:00
Zhuohan Li   2e0d314384  fix-ray (#193)                                                                    2023-06-22 00:21:41 +08:00
6 changed files with 19 additions and 10 deletions

View File

@@ -30,7 +30,7 @@ def main(args: argparse.Namespace):
 
         request_outputs = engine.step()
         for request_output in request_outputs:
-            if request_output.finished():
+            if request_output.finished:
                 print(request_output)
 
         if not (engine.has_unfinished_requests() or test_prompts):
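The change in this example script is the visible symptom of the API fix in outputs.py below: RequestOutput.finished becomes a bool attribute, so call sites drop the parentheses. A minimal sketch of the iteration-level loop this file demonstrates, assuming the 0.1.x EngineArgs/LLMEngine helpers the example script uses; the model name and prompt are illustrative:

from vllm import EngineArgs, LLMEngine, SamplingParams

# Assumed 0.1.x construction helper; the model is a placeholder.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
test_prompts = [("Hello, my name is", SamplingParams(temperature=0.8))]

request_id = 0
while True:
    # Feed one request per step to exercise iteration-level scheduling.
    if test_prompts:
        prompt, sampling_params = test_prompts.pop(0)
        engine.add_request(str(request_id), prompt, sampling_params)
        request_id += 1
    for request_output in engine.step():
        if request_output.finished:  # attribute after this change, not a call
            print(request_output)
    if not (engine.has_unfinished_requests() or test_prompts):
        break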

View File

@@ -6,7 +6,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 __all__ = [
     "LLM",

View File

@@ -154,7 +154,7 @@ class AsyncLLMEngine:
             yield request_output
 
             # Once finished, release the resources of the sequence group.
-            if request_output.finished():
+            if request_output.finished:
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")

View File

@@ -87,7 +87,7 @@ class LLMEngine:
             worker_cls = ray.remote(
                 num_cpus=0,
                 num_gpus=1,
-                resources={node_resource: 1e-5},
+                resources={node_resource: 1e-3},
             )(worker_cls).remote
 
             worker = worker_cls(
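The fix-ray change (#193) raises the per-worker request for the node's custom resource from 1e-5 to 1e-3. Ray only places an actor on a node that advertises the named resource, and it tracks resource quantities with limited precision, so a request as tiny as 1e-5 can effectively be rounded away; 1e-3 keeps the pin-to-node trick working while still consuming a negligible share of the resource. A hedged, standalone sketch of the mechanism (the resource name is made up):

import ray

# Advertise a custom resource on this node; the name is illustrative.
ray.init(resources={"pinned_node": 1.0})

@ray.remote(num_cpus=0, resources={"pinned_node": 1e-3})
def probe() -> str:
    # Ray schedules this task only on a node advertising "pinned_node".
    return "scheduled on the pinned node"

print(ray.get(probe.remote()))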
@@ -127,6 +127,12 @@ class LLMEngine:
         # FIXME(woosuk): Change to debug log.
         logger.info(f'# GPU blocks: {num_gpu_blocks}, '
                     f'# CPU blocks: {num_cpu_blocks}')
+
+        if num_gpu_blocks <= 0 or num_cpu_blocks <= 0:
+            raise ValueError("No available memory for the cache blocks. "
+                             "Try increasing `gpu_memory_utilization` when "
+                             "initializing the engine.")
+
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
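With this guard (#203), a profiling result of zero free cache blocks fails fast with an actionable message instead of crashing later during KV cache allocation. A hedged usage sketch of acting on that error, assuming the 0.1.x gpu_memory_utilization engine argument named in the message; the model and fractions are illustrative:

from vllm import LLM

try:
    # Deliberately starve the cache to trigger the new ValueError.
    llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.05)
except ValueError as e:
    # Fails here with the actionable message above, not during allocation.
    print(f"Engine init failed: {e}")
    llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.90)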

View File

@@ -133,7 +133,7 @@ class LLM:
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
-                if output.finished():
+                if output.finished:
                     outputs.append(output)
                     if use_tqdm:
                         pbar.update(1)
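On the synchronous user-facing path, LLM.generate() drains this same engine loop and collects only finished outputs. A brief hedged usage sketch against the 0.1.x API; model and prompt are placeholders:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.8))
for out in outputs:
    assert out.finished              # bool attribute after this change
    print(out.outputs[0].text)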

View File

@@ -53,6 +53,7 @@ class RequestOutput:
         prompt: The prompt string of the request.
         prompt_token_ids: The token IDs of the prompt.
         outputs: The output sequences of the request.
+        finished: Whether the whole request is finished.
     """
     def __init__(
         self,
@@ -60,11 +61,13 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
+        finished: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
+        self.finished = finished
 
     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -95,13 +98,13 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)
+        finished = seq_group.is_finished()
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
+                   finished)
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"outputs={self.outputs})")
-
-    def finished(self) -> bool:
-        return all(output.finished() for output in self.outputs)
+                f"outputs={self.outputs}, "
+                f"finished={self.finished})")
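Net effect in this file: finished is no longer recomputed from the per-completion outputs on every call; it is a plain bool snapshot of SequenceGroup.is_finished() taken when the RequestOutput is built. A hedged sketch of the new contract; the CompletionOutput fields follow the 0.1.x signature, but the concrete values are made up:

from vllm.outputs import CompletionOutput, RequestOutput

# Assumed 0.1.x CompletionOutput signature; values are illustrative.
completion = CompletionOutput(index=0, text=" world", token_ids=[995],
                              cumulative_logprob=-0.7, logprobs=None,
                              finish_reason="length")
output = RequestOutput(request_id="0", prompt="Hello,",
                       prompt_token_ids=[31414], outputs=[completion],
                       finished=True)

assert output.finished  # attribute access; output.finished() would now
                        # raise TypeError, since finished is a bool
print(output)           # repr now includes finished=True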