[Bug] Fix FlashInfer MNNVL socket collisions under concurrent vLLM jobs (#36674)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-17 15:19:52 -04:00
committed by GitHub
parent 68f783a727
commit bdb903bb5f

View File

@@ -3,6 +3,8 @@
import atexit
import os
import random
import threading
import torch
@@ -67,15 +69,20 @@ def initialize_fi_ar_workspace(
backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
comm_backend = TorchDistBackend(group=group)
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend=backend,
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
rng_state = random.getstate()
try:
random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
_fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
backend=backend,
world_size=world_size,
rank=rank,
max_token_num=max_token_num,
hidden_dim=hidden_dim,
dtype=dtype,
comm_backend=comm_backend,
)
finally:
random.setstate(rng_state)
assert _fi_ar_workspace is not None
logger.debug(
"Initialized FlashInfer All Reduce workspace: backend=%s, "