[Bug] Fix FlashInfer MNNVL socket collisions under concurrent vLLM jobs (#36674)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -3,6 +3,8 @@

 import atexit
+import os
+import random
 import threading

 import torch
@@ -67,15 +69,20 @@ def initialize_fi_ar_workspace(

     backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
     comm_backend = TorchDistBackend(group=group)
-    _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
-        backend=backend,
-        world_size=world_size,
-        rank=rank,
-        max_token_num=max_token_num,
-        hidden_dim=hidden_dim,
-        dtype=dtype,
-        comm_backend=comm_backend,
-    )
+    rng_state = random.getstate()
+    try:
+        random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
+        _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            dtype=dtype,
+            comm_backend=comm_backend,
+        )
+    finally:
+        random.setstate(rng_state)
     assert _fi_ar_workspace is not None
     logger.debug(
         "Initialized FlashInfer All Reduce workspace: backend=%s, "
||||
|
||||
Reference in New Issue
Block a user