[Typing] Mypy typing part 2 (#4043)
Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
This commit is contained in:
@@ -2,8 +2,8 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
|
||||
Set, Tuple, Type, Union)
|
||||
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
|
||||
Optional, Set, Tuple, Type, Union)
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@@ -52,7 +52,7 @@ class AsyncStream:
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
self._queue = asyncio.Queue()
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._finished = False
|
||||
|
||||
def put(self, item: Union[RequestOutput, Exception]) -> None:
|
||||
@@ -312,15 +312,17 @@ class AsyncLLMEngine:
|
||||
self.max_log_len = max_log_len
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
self.background_loop = None
|
||||
self.background_loop: Optional[asyncio.Future] = None
|
||||
# We need to keep a reference to unshielded
|
||||
# task as well to prevent it from being garbage
|
||||
# collected
|
||||
self._background_loop_unshielded = None
|
||||
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
|
||||
self.start_engine_loop = start_engine_loop
|
||||
self._request_tracker: Optional[RequestTracker] = None
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
# Lazy initialized fields
|
||||
self._request_tracker: RequestTracker
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(
|
||||
cls,
|
||||
@@ -361,11 +363,13 @@ class AsyncLLMEngine:
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return (self.background_loop is not None
|
||||
and self._background_loop_unshielded is not None
|
||||
and not self._background_loop_unshielded.done())
|
||||
|
||||
@property
|
||||
def is_stopped(self) -> bool:
|
||||
return self.errored or (self.background_loop is not None
|
||||
return self.errored or (self.background_loop is not None and
|
||||
self._background_loop_unshielded is not None
|
||||
and self._background_loop_unshielded.done())
|
||||
|
||||
@property
|
||||
@@ -381,7 +385,7 @@ class AsyncLLMEngine:
|
||||
|
||||
async def get_tokenizer(self) -> "PreTrainedTokenizer":
|
||||
if self.engine_use_ray:
|
||||
return await self.engine.get_tokenizer.remote()
|
||||
return await self.engine.get_tokenizer.remote() # type: ignore
|
||||
else:
|
||||
return self.engine.get_tokenizer()
|
||||
|
||||
@@ -434,7 +438,8 @@ class AsyncLLMEngine:
|
||||
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||
try:
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(**new_request)
|
||||
await self.engine.add_request.remote( # type: ignore
|
||||
**new_request)
|
||||
else:
|
||||
await self.engine.add_request_async(**new_request)
|
||||
except ValueError as e:
|
||||
@@ -449,7 +454,7 @@ class AsyncLLMEngine:
|
||||
await self._engine_abort(finished_requests)
|
||||
|
||||
if self.engine_use_ray:
|
||||
request_outputs = await self.engine.step.remote()
|
||||
request_outputs = await self.engine.step.remote() # type: ignore
|
||||
else:
|
||||
request_outputs = await self.engine.step_async()
|
||||
|
||||
@@ -462,7 +467,7 @@ class AsyncLLMEngine:
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_ids)
|
||||
await self.engine.abort_request.remote(request_ids) # type: ignore
|
||||
else:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
@@ -525,11 +530,12 @@ class AsyncLLMEngine:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.engine_use_ray:
|
||||
prompt_token_ids = await self.engine.encode_request_async.remote(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request)
|
||||
prompt_token_ids = await (
|
||||
self.engine.encode_request_async.remote( # type: ignore
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request))
|
||||
else:
|
||||
prompt_token_ids = await self.engine.encode_request_async(
|
||||
request_id=request_id,
|
||||
@@ -676,13 +682,13 @@ class AsyncLLMEngine:
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
if self.engine_use_ray:
|
||||
return await self.engine.get_model_config.remote()
|
||||
return await self.engine.get_model_config.remote() # type: ignore
|
||||
else:
|
||||
return self.engine.get_model_config()
|
||||
|
||||
async def do_log_stats(self) -> None:
|
||||
if self.engine_use_ray:
|
||||
await self.engine.do_log_stats.remote()
|
||||
await self.engine.do_log_stats.remote() # type: ignore
|
||||
else:
|
||||
self.engine.do_log_stats()
|
||||
|
||||
@@ -695,7 +701,7 @@ class AsyncLLMEngine:
|
||||
|
||||
if self.engine_use_ray:
|
||||
try:
|
||||
await self.engine.check_health.remote()
|
||||
await self.engine.check_health.remote() # type: ignore
|
||||
except ray.exceptions.RayActorError as e:
|
||||
raise RuntimeError("Engine is dead.") from e
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user