[Quality] Add code formatter and linter (#326)
This commit is contained in:
@@ -11,7 +11,7 @@ from vllm.sampling_params import SamplingParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
||||
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
|
||||
|
||||
|
||||
class AsyncLLMEngine:
|
||||
@@ -35,8 +35,13 @@ class AsyncLLMEngine:
|
||||
log_requests: Whether to log the requests.
|
||||
*args, **kwargs: Arguments for LLMEngine.
|
||||
"""
|
||||
def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
|
||||
log_requests: bool = True, *args, **kwargs) -> None:
|
||||
|
||||
def __init__(self,
|
||||
worker_use_ray: bool,
|
||||
engine_use_ray: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
**kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.engine_use_ray = engine_use_ray
|
||||
self.log_requests = log_requests
|
||||
@@ -76,12 +81,11 @@ class AsyncLLMEngine:
|
||||
self.request_events[request_id].set()
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None
|
||||
) -> RequestOutput:
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
@@ -117,14 +121,17 @@ class AsyncLLMEngine:
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(
|
||||
request_id, prompt, sampling_params,
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
else:
|
||||
self.engine.add_request(
|
||||
request_id, prompt, sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
self.engine.add_request(request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
# The vLLM engine does not have a background loop that keeps
|
||||
# processing incoming requests. Therefore, we need to keep kicking
|
||||
@@ -200,7 +207,8 @@ class AsyncLLMEngine:
|
||||
self.kicking_request_id = None
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
@@ -211,8 +219,9 @@ class AsyncLLMEngine:
|
||||
# Create the async LLM engine.
|
||||
engine = cls(engine_args.worker_use_ray,
|
||||
engine_args.engine_use_ray,
|
||||
not engine_args.disable_log_requests,
|
||||
*engine_configs,
|
||||
distributed_init_method, devices,
|
||||
distributed_init_method,
|
||||
devices,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats)
|
||||
return engine
|
||||
|
||||
Reference in New Issue
Block a user