Compare commits

..

3 Commits

Author SHA1 Message Date
Zhuohan Li
1117aa1411 Bump up the version to v0.1.6 (#989)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.10) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.11) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.8) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.9) (push) Has been cancelled
2023-09-08 00:07:46 -07:00
Antoni Baum
080438477f Start background task in AsyncLLMEngine.generate (#988)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-09-08 00:03:39 -07:00
Robert Irvine
4b5bcf8906 faster startup of vLLM (#982)
* update

---------

Co-authored-by: Robert Irvine <robert@seamlessml.com>
2023-09-08 14:48:54 +09:00
6 changed files with 20 additions and 27 deletions

View File

@@ -40,8 +40,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
-                                                      start_engine_loop=False)
+    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
     vllm.entrypoints.api_server.engine = engine
     uvicorn.run(
         app,

View File

@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams

-__version__ = "0.1.5"
+__version__ = "0.1.6"

 __all__ = [
     "LLM",

View File

@@ -230,6 +230,8 @@ class AsyncLLMEngine:
            async frontend will be executed in a separate process as the
            model workers.
        log_requests: Whether to log the requests.
+       start_engine_loop: If True, the background task to run the engine
+           will be automatically started in the generate call.
        *args, *kwargs: Arguments for LLMEngine.
    """
@@ -240,7 +242,7 @@ class AsyncLLMEngine:
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
-                start_engine_loop: bool = False,
+                start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
@@ -249,8 +251,7 @@ class AsyncLLMEngine:
self.request_tracker: RequestTracker = RequestTracker() self.request_tracker: RequestTracker = RequestTracker()
self.background_loop = None self.background_loop = None
if start_engine_loop: self.start_engine_loop = start_engine_loop
self.start_background_loop()
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@@ -330,6 +331,9 @@ class AsyncLLMEngine:
                    f"prompt token ids: {prompt_token_ids}.")

        if not self.is_running:
+           if self.start_engine_loop:
+               self.start_background_loop()
+           else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
@@ -426,7 +430,7 @@ class AsyncLLMEngine:
    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs,
-                        start_engine_loop: bool = False) -> "AsyncLLMEngine":
+                        start_engine_loop: bool = True) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()

View File

@@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

-   if not engine.is_running:
-       engine.start_background_loop()
-
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
@@ -80,8 +77,7 @@ if __name__ == "__main__":
    args = parser.parse_args()
    engine_args = AsyncEngineArgs.from_cli_args(args)
-   engine = AsyncLLMEngine.from_engine_args(engine_args,
-                                            start_engine_loop=False)
+   engine = AsyncLLMEngine.from_engine_args(engine_args)

    uvicorn.run(app,
                host=args.host,

View File

@@ -192,9 +192,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
    """
    logger.info(f"Received chat completion request: {request}")

-   if not engine.is_running:
-       engine.start_background_loop()
-
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
@@ -367,9 +364,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
    """
    logger.info(f"Received completion request: {request}")

-   if not engine.is_running:
-       engine.start_background_loop()
-
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
@@ -627,8 +621,7 @@ if __name__ == "__main__":
served_model = args.model served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args, engine = AsyncLLMEngine.from_engine_args(engine_args)
start_engine_loop=False)
engine_model_config = asyncio.run(engine.get_model_config()) engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.get_max_model_len() max_model_len = engine_model_config.get_max_model_len()

View File

@@ -259,8 +259,9 @@ class PagedAttentionWithRoPE(PagedAttention):
        self.is_neox_style = is_neox_style

        # Create the cos and sin cache.
-       inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
-       t = torch.arange(max_position).float()
+       inv_freq = 1.0 / (base**(
+           torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
+       t = torch.arange(max_position, device="cuda").float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()