Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1117aa1411 | ||
|
|
080438477f | ||
|
|
4b5bcf8906 |
@@ -40,8 +40,7 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
|
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
|
||||||
start_engine_loop=False)
|
|
||||||
vllm.entrypoints.api_server.engine = engine
|
vllm.entrypoints.api_server.engine = engine
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
app,
|
app,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
|||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
|
||||||
__version__ = "0.1.5"
|
__version__ = "0.1.6"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LLM",
|
"LLM",
|
||||||
|
|||||||
@@ -230,6 +230,8 @@ class AsyncLLMEngine:
|
|||||||
async frontend will be executed in a separate process as the
|
async frontend will be executed in a separate process as the
|
||||||
model workers.
|
model workers.
|
||||||
log_requests: Whether to log the requests.
|
log_requests: Whether to log the requests.
|
||||||
|
start_engine_loop: If True, the background task to run the engine
|
||||||
|
will be automatically started in the generate call.
|
||||||
*args, **kwargs: Arguments for LLMEngine.
|
*args, **kwargs: Arguments for LLMEngine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -240,7 +242,7 @@ class AsyncLLMEngine:
|
|||||||
engine_use_ray: bool,
|
engine_use_ray: bool,
|
||||||
*args,
|
*args,
|
||||||
log_requests: bool = True,
|
log_requests: bool = True,
|
||||||
start_engine_loop: bool = False,
|
start_engine_loop: bool = True,
|
||||||
**kwargs) -> None:
|
**kwargs) -> None:
|
||||||
self.worker_use_ray = worker_use_ray
|
self.worker_use_ray = worker_use_ray
|
||||||
self.engine_use_ray = engine_use_ray
|
self.engine_use_ray = engine_use_ray
|
||||||
@@ -249,8 +251,7 @@ class AsyncLLMEngine:
|
|||||||
|
|
||||||
self.request_tracker: RequestTracker = RequestTracker()
|
self.request_tracker: RequestTracker = RequestTracker()
|
||||||
self.background_loop = None
|
self.background_loop = None
|
||||||
if start_engine_loop:
|
self.start_engine_loop = start_engine_loop
|
||||||
self.start_background_loop()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
@@ -330,11 +331,14 @@ class AsyncLLMEngine:
|
|||||||
f"prompt token ids: {prompt_token_ids}.")
|
f"prompt token ids: {prompt_token_ids}.")
|
||||||
|
|
||||||
if not self.is_running:
|
if not self.is_running:
|
||||||
raise AsyncEngineDeadError(
|
if self.start_engine_loop:
|
||||||
"Background loop is not running. If it was running, "
|
self.start_background_loop()
|
||||||
"inspect the output to find the stacktrace of the "
|
else:
|
||||||
"error that caused the background loop to stop "
|
raise AsyncEngineDeadError(
|
||||||
"(AsyncEngineDeadError).")
|
"Background loop is not running. If it was running, "
|
||||||
|
"inspect the output to find the stacktrace of the "
|
||||||
|
"error that caused the background loop to stop "
|
||||||
|
"(AsyncEngineDeadError).")
|
||||||
|
|
||||||
stream = self.request_tracker.add_request(
|
stream = self.request_tracker.add_request(
|
||||||
request_id,
|
request_id,
|
||||||
@@ -426,7 +430,7 @@ class AsyncLLMEngine:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_engine_args(cls,
|
def from_engine_args(cls,
|
||||||
engine_args: AsyncEngineArgs,
|
engine_args: AsyncEngineArgs,
|
||||||
start_engine_loop: bool = False) -> "AsyncLLMEngine":
|
start_engine_loop: bool = True) -> "AsyncLLMEngine":
|
||||||
"""Creates an async LLM engine from the engine arguments."""
|
"""Creates an async LLM engine from the engine arguments."""
|
||||||
# Create the engine configs.
|
# Create the engine configs.
|
||||||
engine_configs = engine_args.create_engine_configs()
|
engine_configs = engine_args.create_engine_configs()
|
||||||
|
|||||||
@@ -32,9 +32,6 @@ async def generate(request: Request) -> Response:
|
|||||||
sampling_params = SamplingParams(**request_dict)
|
sampling_params = SamplingParams(**request_dict)
|
||||||
request_id = random_uuid()
|
request_id = random_uuid()
|
||||||
|
|
||||||
if not engine.is_running:
|
|
||||||
engine.start_background_loop()
|
|
||||||
|
|
||||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||||
|
|
||||||
# Streaming case
|
# Streaming case
|
||||||
@@ -80,8 +77,7 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
start_engine_loop=False)
|
|
||||||
|
|
||||||
uvicorn.run(app,
|
uvicorn.run(app,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
|
|||||||
@@ -192,9 +192,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Received chat completion request: {request}")
|
logger.info(f"Received chat completion request: {request}")
|
||||||
|
|
||||||
if not engine.is_running:
|
|
||||||
engine.start_background_loop()
|
|
||||||
|
|
||||||
error_check_ret = await check_model(request)
|
error_check_ret = await check_model(request)
|
||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
@@ -367,9 +364,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||||||
"""
|
"""
|
||||||
logger.info(f"Received completion request: {request}")
|
logger.info(f"Received completion request: {request}")
|
||||||
|
|
||||||
if not engine.is_running:
|
|
||||||
engine.start_background_loop()
|
|
||||||
|
|
||||||
error_check_ret = await check_model(request)
|
error_check_ret = await check_model(request)
|
||||||
if error_check_ret is not None:
|
if error_check_ret is not None:
|
||||||
return error_check_ret
|
return error_check_ret
|
||||||
@@ -627,8 +621,7 @@ if __name__ == "__main__":
|
|||||||
served_model = args.model
|
served_model = args.model
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
engine = AsyncLLMEngine.from_engine_args(engine_args,
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||||
start_engine_loop=False)
|
|
||||||
engine_model_config = asyncio.run(engine.get_model_config())
|
engine_model_config = asyncio.run(engine.get_model_config())
|
||||||
max_model_len = engine_model_config.get_max_model_len()
|
max_model_len = engine_model_config.get_max_model_len()
|
||||||
|
|
||||||
|
|||||||
@@ -259,8 +259,9 @@ class PagedAttentionWithRoPE(PagedAttention):
|
|||||||
self.is_neox_style = is_neox_style
|
self.is_neox_style = is_neox_style
|
||||||
|
|
||||||
# Create the cos and sin cache.
|
# Create the cos and sin cache.
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
|
inv_freq = 1.0 / (base**(
|
||||||
t = torch.arange(max_position).float()
|
torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
|
||||||
|
t = torch.arange(max_position, device="cuda").float()
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
|
|||||||
Reference in New Issue
Block a user