[Quality] Add code formatter and linter (#326)

This commit is contained in:
Zhuohan Li
2023-07-03 11:31:55 -07:00
committed by GitHub
parent 0ffded812a
commit d6fa1be3a8
47 changed files with 1547 additions and 617 deletions

View File

@@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
class AsyncLLMEngine:
@@ -35,8 +35,13 @@ class AsyncLLMEngine:
log_requests: Whether to log the requests.
*args, *kwargs: Arguments for LLMEngine.
"""
def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
log_requests: bool = True, *args, **kwargs) -> None:
def __init__(self,
worker_use_ray: bool,
engine_use_ray: bool,
*args,
log_requests: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
@@ -76,12 +81,11 @@ class AsyncLLMEngine:
self.request_events[request_id].set()
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None
) -> RequestOutput:
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
@@ -117,14 +121,17 @@ class AsyncLLMEngine:
# Add the request into the vLLM engine's waiting queue.
if self.engine_use_ray:
await self.engine.add_request.remote(
request_id, prompt, sampling_params,
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
self.engine.add_request(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
self.engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
# The vLLM engine does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
@@ -200,7 +207,8 @@ class AsyncLLMEngine:
self.kicking_request_id = None
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
def from_engine_args(cls,
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
@@ -211,8 +219,9 @@ class AsyncLLMEngine:
# Create the async LLM engine.
engine = cls(engine_args.worker_use_ray,
engine_args.engine_use_ray,
not engine_args.disable_log_requests,
*engine_configs,
distributed_init_method, devices,
distributed_init_method,
devices,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats)
return engine