[misc] Add Torch profiler support (#7451)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -305,6 +305,26 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
"used for local development!")
|
||||
|
||||
@router.post("/start_profile")
|
||||
async def start_profile():
|
||||
logger.info("Starting profiler...")
|
||||
await async_engine_client.start_profile()
|
||||
logger.info("Profiler started.")
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/stop_profile")
|
||||
async def stop_profile():
|
||||
logger.info("Stopping profiler...")
|
||||
await async_engine_client.stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(router)
|
||||
|
||||
@@ -46,6 +46,8 @@ class RPCUtilityRequest(Enum):
|
||||
DO_LOG_STATS = 7
|
||||
IS_SERVER_HEALTHY = 8
|
||||
IS_TRACING_ENABLED = 9
|
||||
START_PROFILE = 10
|
||||
STOP_PROFILE = 11
|
||||
|
||||
|
||||
RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
|
||||
|
||||
@@ -400,3 +400,17 @@ class AsyncEngineRPCClient:
|
||||
**kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
|
||||
raise NotImplementedError(
|
||||
"Embeddings not supported with multiprocessing backend")
|
||||
|
||||
async def start_profile(self) -> None:
|
||||
"""Start profiling the engine"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.START_PROFILE,
|
||||
error_message="RPCRequest START_PROFILE failed.")
|
||||
|
||||
async def stop_profile(self) -> None:
|
||||
"""Stop profiling the engine"""
|
||||
|
||||
await self._send_one_way_rpc_request(
|
||||
request=RPCUtilityRequest.STOP_PROFILE,
|
||||
error_message="RPCRequest STOP_PROFILE failed.")
|
||||
@@ -124,6 +124,26 @@ class AsyncEngineRPCServer:
|
||||
except Exception as e:
|
||||
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
|
||||
|
||||
async def start_profile(self, identity):
|
||||
logger.info("Starting profiler...")
|
||||
await self.engine.start_profile()
|
||||
logger.info("Profiler started.")
|
||||
|
||||
await self.socket.send_multipart([
|
||||
identity,
|
||||
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
|
||||
])
|
||||
|
||||
async def stop_profile(self, identity):
|
||||
logger.info("Stopping profiler...")
|
||||
await self.engine.stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
|
||||
await self.socket.send_multipart([
|
||||
identity,
|
||||
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
|
||||
])
|
||||
|
||||
def _make_handler_coro(self, identity,
|
||||
message) -> Coroutine[Any, Any, Never]:
|
||||
"""Route the zmq message to the handler coroutine."""
|
||||
@@ -153,6 +173,10 @@ class AsyncEngineRPCServer:
|
||||
return self.check_health(identity)
|
||||
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
|
||||
return self.is_tracing_enabled(identity)
|
||||
elif request == RPCUtilityRequest.START_PROFILE:
|
||||
return self.start_profile(identity)
|
||||
elif request == RPCUtilityRequest.STOP_PROFILE:
|
||||
return self.stop_profile(identity)
|
||||
else:
|
||||
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user