[bugfix][async scheduling] fix extra cuda context in device 0 with EP/DP (#37449)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -597,17 +597,6 @@ class WorkerProc:
|
||||
wrapper.init_worker(all_kwargs)
|
||||
self.worker = wrapper
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.use_async_scheduling = scheduler_config.async_scheduling
|
||||
if self.use_async_scheduling:
|
||||
self.async_output_queue: queue.Queue = queue.Queue()
|
||||
self.async_output_copy_thread = Thread(
|
||||
target=self.async_output_busy_loop,
|
||||
daemon=True,
|
||||
name="WorkerAsyncOutputCopy",
|
||||
)
|
||||
self.async_output_copy_thread.start()
|
||||
|
||||
self.setup_proc_title_and_log_prefix(
|
||||
enable_ep=vllm_config.parallel_config.enable_expert_parallel
|
||||
)
|
||||
@@ -622,6 +611,17 @@ class WorkerProc:
|
||||
)
|
||||
self.worker.load_model()
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.use_async_scheduling = scheduler_config.async_scheduling
|
||||
if self.use_async_scheduling:
|
||||
self.async_output_queue: queue.Queue = queue.Queue()
|
||||
self.async_output_copy_thread = Thread(
|
||||
target=self.async_output_busy_loop,
|
||||
daemon=True,
|
||||
name="WorkerAsyncOutputCopy",
|
||||
)
|
||||
self.async_output_copy_thread.start()
|
||||
|
||||
# Set block size based on the attention backends
|
||||
current_platform.update_block_size_for_backend(vllm_config)
|
||||
|
||||
@@ -911,6 +911,18 @@ class WorkerProc:
|
||||
|
||||
def async_output_busy_loop(self):
    """Entrypoint for the thread which handles outputs asynchronously."""
    from vllm.platforms import current_platform

    # A freshly spawned thread does not inherit the main thread's CUDA
    # context. The first CUDA runtime call made from this thread would
    # therefore implicitly create a brand-new context on device 0,
    # consuming extra memory there. Pinning the thread to the worker's
    # device up front forces it onto the same context as the main thread.
    worker = self.worker
    if hasattr(worker, "device"):
        current_platform.set_device(worker.device)

    # Drain the async output queue forever, handing each item to the
    # regular output path. `Queue.get()` blocks, so the loop idles
    # cheaply while no output is pending.
    while True:
        self.enqueue_output(self.async_output_queue.get())
|
||||
|
||||
Reference in New Issue
Block a user