[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)

Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
This commit is contained in:
Hanjie Qiu
2026-02-25 19:17:20 -08:00
committed by GitHub
parent 2aa4140402
commit 71dfce6aa6
7 changed files with 593 additions and 180 deletions

View File

@@ -30,6 +30,9 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
FlashInferAllReduce,
)
from vllm.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
register_nccl_symmetric_ops,
@@ -44,7 +47,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
# Default sequence lengths to benchmark
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192]
# Fixed hidden size and dtype for all benchmarks
HIDDEN_SIZE = 8192
@@ -81,6 +84,7 @@ class CommunicatorBenchmark:
self.symm_mem_comm = None
self.symm_mem_comm_multimem = None
self.symm_mem_comm_two_shot = None
self.fi_ar_comm = None
self._init_communicators()
@@ -161,6 +165,22 @@ class CommunicatorBenchmark:
)
self.symm_mem_comm_two_shot = None
try:
self.fi_ar_comm = FlashInferAllReduce(
group=self.cpu_group,
device=self.device,
)
if not self.fi_ar_comm.disabled:
logger.info("Rank %s: FlashInferAllReduce initialized", self.rank)
else:
logger.info("Rank %s: FlashInferAllReduce disabled", self.rank)
self.fi_ar_comm = None
except Exception as e:
logger.warning(
"Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e
)
self.fi_ar_comm = None
def benchmark_allreduce(
self, sequence_length: int, num_warmup: int, num_trials: int
) -> dict[str, float]:
@@ -180,7 +200,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"1stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"},
None, # no destroy function
)
)
# CustomAllreduce two-shot
@@ -190,7 +211,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.custom_all_reduce(t),
lambda t, c=comm: c.should_custom_ar(t),
comm.capture(),
"2stage", # env variable value
{"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"},
None, # no destroy function
)
)
@@ -202,7 +224,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
communicators.append(
@@ -211,7 +234,8 @@ class CommunicatorBenchmark:
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
lambda t: True, # Always available if initialized
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
@@ -223,7 +247,8 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function
)
)
@@ -235,29 +260,67 @@ class CommunicatorBenchmark:
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_symm_mem(t),
nullcontext(),
None, # no env variable needed
{}, # no env variable needed
None, # no destroy function needed
)
)
if self.fi_ar_comm is not None:
comm = self.fi_ar_comm
communicators.append(
(
"flashinfer_trtllm",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"},
lambda c=comm: c.destroy(),
)
)
communicators.append(
(
"flashinfer_mnnvl",
lambda t, c=comm: c.all_reduce(t),
lambda t, c=comm: c.should_use_fi_ar(t),
nullcontext(),
{"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"},
lambda c=comm: c.destroy(),
)
)
# Benchmark each communicator
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
# Set environment variable if needed
if env_var is not None:
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
else:
# Clear the environment variable to avoid interference
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
)
if latency is not None:
results[name] = latency
for (
name,
allreduce_fn,
should_use_fn,
context,
env_dict,
destroy_fn,
) in communicators:
# Save original values and apply new environment variables
saved_env = {key: os.environ.get(key) for key in env_dict}
for key, value in env_dict.items():
os.environ[key] = value
try:
latency = self.benchmark_allreduce_single(
sequence_length,
allreduce_fn,
should_use_fn,
context,
num_warmup,
num_trials,
)
if latency is not None:
results[name] = latency
finally:
if destroy_fn is not None:
destroy_fn()
# Restore environment variables to their original state
for key, original_value in saved_env.items():
if original_value is None:
os.environ.pop(key, None)
else:
os.environ[key] = original_value
return results