diff --git a/vllm/distributed/device_communicators/flashinfer_all_reduce.py b/vllm/distributed/device_communicators/flashinfer_all_reduce.py
index 1152277f7..66e089182 100644
--- a/vllm/distributed/device_communicators/flashinfer_all_reduce.py
+++ b/vllm/distributed/device_communicators/flashinfer_all_reduce.py
@@ -3,6 +3,8 @@
 import atexit
+import os
+import random
 import threading
 
 import torch
@@ -67,15 +69,20 @@ def initialize_fi_ar_workspace(
     backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
     comm_backend = TorchDistBackend(group=group)
 
-    _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
-        backend=backend,
-        world_size=world_size,
-        rank=rank,
-        max_token_num=max_token_num,
-        hidden_dim=hidden_dim,
-        dtype=dtype,
-        comm_backend=comm_backend,
-    )
+    rng_state = random.getstate()
+    try:
+        random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
+        _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace(
+            backend=backend,
+            world_size=world_size,
+            rank=rank,
+            max_token_num=max_token_num,
+            hidden_dim=hidden_dim,
+            dtype=dtype,
+            comm_backend=comm_backend,
+        )
+    finally:
+        random.setstate(rng_state)
     assert _fi_ar_workspace is not None
     logger.debug(
         "Initialized FlashInfer All Reduce workspace: backend=%s, "
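
For context, the diff above runs `flashinfer_comm.create_allreduce_fusion_workspace()` with a freshly seeded `random` module and restores the caller's global RNG state afterwards, so workspace creation neither depends on nor perturbs any seed set elsewhere in the process. A minimal standalone sketch of the same save/seed/restore pattern; the helper name `with_isolated_rng` is illustrative and not part of vLLM or FlashInfer:

```python
import os
import random


def with_isolated_rng(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) with an os.urandom-seeded `random` module,
    restoring the previous global RNG state afterwards."""
    rng_state = random.getstate()
    try:
        # Seed from the OS entropy pool so the call does not rely on (or
        # disturb) whatever seed the rest of the process is using.
        random.seed(int.from_bytes(os.urandom(16), byteorder="big"))
        return fn(*args, **kwargs)
    finally:
        # Restore the caller's state even if fn raises.
        random.setstate(rng_state)


if __name__ == "__main__":
    random.seed(0)
    before = random.getstate()
    with_isolated_rng(random.random)          # draws from a fresh seed
    assert random.getstate() == before        # caller's RNG state preserved
```

Restoring the state in `finally` guarantees that reproducibility elsewhere (e.g. seeded sampling) is unaffected even when workspace creation fails.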