[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
@@ -4,9 +4,12 @@ from collections import defaultdict
 from itertools import islice, repeat
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
+import msgspec
+
 import vllm.envs as envs
 from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
+from vllm.executor.msgspec_utils import encode_hook
 from vllm.executor.ray_utils import RayWorkerWrapper, ray
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
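
The new imports wire msgspec's msgpack codec into the executor. As a minimal
sketch of the enc_hook pattern (the real encode_hook lives in
vllm/executor/msgspec_utils.py; demo_encode_hook and the complex-number case
below are purely illustrative, following msgspec's documented hook shape):

    from typing import Any

    import msgspec

    def demo_encode_hook(obj: Any) -> Any:
        # Convert objects msgspec cannot serialize natively into ones it
        # can; unknown types should raise so bugs surface early.
        if isinstance(obj, complex):
            return (obj.real, obj.imag)
        raise NotImplementedError(f"Unsupported type: {type(obj)}")

    encoder = msgspec.msgpack.Encoder(enc_hook=demo_encode_hook)
    payload = encoder.encode({"value": 3 + 4j})  # compact msgpack bytes
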
@@ -60,6 +63,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)
 
+        self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
+        self.output_decoder = msgspec.msgpack.Decoder(
+            Optional[List[SamplerOutput]])
+
     def shutdown(self) -> None:
         if hasattr(self, "forward_dag") and self.forward_dag is not None:
             self.forward_dag.teardown()
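
The executor now builds one reusable encoder/decoder pair instead of
serializing ad hoc. A self-contained sketch of the typed-decoder round trip,
using a hypothetical msgspec.Struct in place of vLLM's SamplerOutput:

    from typing import List, Optional

    import msgspec

    class FakeOutput(msgspec.Struct):
        # Hypothetical stand-in for SamplerOutput.
        token_id: int
        logprob: float

    encoder = msgspec.msgpack.Encoder()
    decoder = msgspec.msgpack.Decoder(Optional[List[FakeOutput]])

    wire = encoder.encode([FakeOutput(token_id=42, logprob=-0.5)])
    outputs = decoder.decode(wire)
    assert outputs is not None and outputs[0].token_id == 42

Typing the decoder up front lets msgspec validate and rebuild the structs
directly from the wire bytes, with no pickle involved.
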
@@ -123,6 +130,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 ray_remote_kwargs)
 
         logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
+
         # Create the workers.
         driver_ip = get_ip()
         worker_wrapper_kwargs = self._get_worker_wrapper_args()
@@ -304,8 +312,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
         if self.forward_dag is None:
             self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
 
-        outputs = ray.get(self.forward_dag.execute(execute_model_req))
-        return outputs[0]
+        serialized_data = self.input_encoder.encode(execute_model_req)
+        outputs = ray.get(self.forward_dag.execute(serialized_data))
+        output = self.output_decoder.decode(outputs[0])
+        return output
 
     def _run_workers(
         self,
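
The synchronous path now ships msgpack bytes through the compiled Ray DAG
instead of the raw ExecuteModelRequest object. A runnable sketch of that
round trip, with a hypothetical echo() standing in for forward_dag.execute():

    from typing import List, Optional

    import msgspec

    class FakeRequest(msgspec.Struct):
        # Hypothetical stand-in for ExecuteModelRequest.
        seq_ids: List[int]

    class FakeOutput(msgspec.Struct):
        token_ids: List[int]

    input_encoder = msgspec.msgpack.Encoder()
    output_decoder = msgspec.msgpack.Decoder(Optional[List[FakeOutput]])

    def echo(serialized: bytes) -> List[bytes]:
        # Stand-in for forward_dag.execute(): a worker would decode the
        # request, run the model, and return msgpack-encoded outputs.
        req = msgspec.msgpack.decode(serialized, type=FakeRequest)
        return [msgspec.msgpack.encode([FakeOutput(token_ids=req.seq_ids)])]

    serialized_data = input_encoder.encode(FakeRequest(seq_ids=[1, 2]))
    outputs = echo(serialized_data)
    result = output_decoder.decode(outputs[0])
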
@@ -475,9 +485,10 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
         if self.forward_dag is None:
             self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
 
-        dag_future = await self.forward_dag.execute_async(execute_model_req)
+        serialized_data = self.input_encoder.encode(execute_model_req)
+        dag_future = await self.forward_dag.execute_async(serialized_data)
         outputs = await dag_future
-        return outputs[0]
+        return self.output_decoder.decode(outputs[0])
 
     async def _driver_execute_model_async(
         self,
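
The async path mirrors the same encode/decode shape, including the double
await (execute_async() returns a future that is awaited again for the
outputs). A sketch with a hypothetical execute_async_stub() in place of the
compiled DAG:

    import asyncio

    import msgspec

    async def execute_async_stub(serialized: bytes) -> "asyncio.Task":
        # Mirrors the call shape above: awaiting the stub yields a future
        # that must be awaited a second time for the actual outputs.
        async def result() -> list:
            return [serialized]
        return asyncio.ensure_future(result())

    async def main() -> None:
        encoder = msgspec.msgpack.Encoder()
        dag_future = await execute_async_stub(encoder.encode({"step": 1}))
        outputs = await dag_future
        print(msgspec.msgpack.decode(outputs[0]))  # -> {'step': 1}

    asyncio.run(main())

In both paths the encode/decode work stays on the driver; only msgpack bytes
cross the DAG channel, which is presumably the serialization optimization the
commit title refers to.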