Usage Stats Collection (#2852)

This commit is contained in:
yhu422
2024-03-28 22:16:12 -07:00
committed by GitHub
parent 7bc94a0fdd
commit d8658c8cc1
15 changed files with 362 additions and 24 deletions

View File

@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = int(
@@ -319,9 +320,12 @@ class AsyncLLMEngine:
self._errored_with: Optional[BaseException] = None
@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
@@ -341,14 +345,17 @@ class AsyncLLMEngine:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
engine = cls(
parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
)
return engine
@property