[Quality] Add CI for formatting (#343)
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.ray_utils import initialize_cluster, ray
|
||||
@@ -206,6 +207,13 @@ class AsyncLLMEngine:
|
||||
self.is_engine_running = False
|
||||
self.kicking_request_id = None
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
if self.engine_use_ray:
|
||||
return await self.engine.get_model_config.remote()
|
||||
else:
|
||||
return self.engine.get_model_config()
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
|
||||
|
||||
@@ -210,6 +210,10 @@ class LLMEngine:
|
||||
"""
|
||||
self.scheduler.abort_seq_group(request_id)
|
||||
|
||||
def get_model_config(self) -> ModelConfig:
|
||||
"""Gets the model configuration."""
|
||||
return self.model_config
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
"""Gets the number of unfinished requests."""
|
||||
return self.scheduler.get_num_unfinished_seq_groups()
|
||||
|
||||
Reference in New Issue
Block a user