[Typing] Mypy typing part 2 (#4043)

Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
This commit is contained in:
SangBin Cho
2024-04-18 09:28:43 +09:00
committed by GitHub
parent a53222544c
commit 533d2a1f39
20 changed files with 180 additions and 126 deletions

View File

@@ -2,8 +2,8 @@ import asyncio
import os
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
@@ -52,7 +52,7 @@ class AsyncStream:
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue = asyncio.Queue()
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False
def put(self, item: Union[RequestOutput, Exception]) -> None:
@@ -312,15 +312,17 @@ class AsyncLLMEngine:
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.background_loop = None
self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded = None
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
self.start_engine_loop = start_engine_loop
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
# Lazy initialized fields
self._request_tracker: RequestTracker
@classmethod
def from_engine_args(
cls,
@@ -361,11 +363,13 @@ class AsyncLLMEngine:
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and self._background_loop_unshielded is not None
and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
return self.errored or (self.background_loop is not None
return self.errored or (self.background_loop is not None and
self._background_loop_unshielded is not None
and self._background_loop_unshielded.done())
@property
@@ -381,7 +385,7 @@ class AsyncLLMEngine:
async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote()
return await self.engine.get_tokenizer.remote() # type: ignore
else:
return self.engine.get_tokenizer()
@@ -434,7 +438,8 @@ class AsyncLLMEngine:
# TODO: Maybe add add_request_batch to reduce Ray overhead
try:
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
await self.engine.add_request.remote( # type: ignore
**new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
@@ -449,7 +454,7 @@ class AsyncLLMEngine:
await self._engine_abort(finished_requests)
if self.engine_use_ray:
request_outputs = await self.engine.step.remote()
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()
@@ -462,7 +467,7 @@ class AsyncLLMEngine:
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids)
await self.engine.abort_request.remote(request_ids) # type: ignore
else:
self.engine.abort_request(request_ids)
@@ -525,11 +530,12 @@ class AsyncLLMEngine:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await self.engine.encode_request_async.remote(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
else:
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
@@ -676,13 +682,13 @@ class AsyncLLMEngine:
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_model_config.remote()
return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config()
async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote()
await self.engine.do_log_stats.remote() # type: ignore
else:
self.engine.do_log_stats()
@@ -695,7 +701,7 @@ class AsyncLLMEngine:
if self.engine_use_ray:
try:
await self.engine.check_health.remote()
await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e
else: