[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)
Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
@@ -30,6 +30,9 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+from vllm.distributed.device_communicators.flashinfer_all_reduce import (
+    FlashInferAllReduce,
+)
 from vllm.distributed.device_communicators.pynccl import (
     PyNcclCommunicator,
     register_nccl_symmetric_ops,
@@ -44,7 +47,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
 logger = init_logger(__name__)
 
 # Default sequence lengths to benchmark
-DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
+DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192]
 
 # Fixed hidden size and dtype for all benchmarks
 HIDDEN_SIZE = 8192
@@ -81,6 +84,7 @@ class CommunicatorBenchmark:
         self.symm_mem_comm = None
         self.symm_mem_comm_multimem = None
         self.symm_mem_comm_two_shot = None
+        self.fi_ar_comm = None
 
         self._init_communicators()
@@ -161,6 +165,22 @@ class CommunicatorBenchmark:
             )
             self.symm_mem_comm_two_shot = None
 
+        try:
+            self.fi_ar_comm = FlashInferAllReduce(
+                group=self.cpu_group,
+                device=self.device,
+            )
+            if not self.fi_ar_comm.disabled:
+                logger.info("Rank %s: FlashInferAllReduce initialized", self.rank)
+            else:
+                logger.info("Rank %s: FlashInferAllReduce disabled", self.rank)
+                self.fi_ar_comm = None
+        except Exception as e:
+            logger.warning(
+                "Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e
+            )
+            self.fi_ar_comm = None
+
     def benchmark_allreduce(
         self, sequence_length: int, num_warmup: int, num_trials: int
     ) -> dict[str, float]:
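The initialization added above follows the guarded pattern the benchmark already uses for its other optional backends: construct the communicator, treat a set `disabled` flag as a soft failure, and catch hard failures so the remaining backends still get benchmarked. A minimal self-contained sketch of that pattern, with a hypothetical `FakeComm` standing in for a real communicator (neither `FakeComm` nor `init_optional_comm` is part of the diff):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class FakeComm:
    """Hypothetical communicator; real ones may raise or self-disable."""

    def __init__(self, device=None):
        # Illustrative condition only: disable when no device is given.
        self.disabled = device is None


def init_optional_comm(rank, **kwargs):
    """Return a usable communicator or None; never raise."""
    try:
        comm = FakeComm(**kwargs)
        if comm.disabled:
            logger.info("Rank %s: communicator disabled", rank)
            return None
        logger.info("Rank %s: communicator initialized", rank)
        return comm
    except Exception as e:
        logger.warning("Rank %s: failed to initialize: %s", rank, e)
        return None


print(init_optional_comm(0, device="cuda:0"))  # FakeComm instance
print(init_optional_comm(0))                   # None (disabled)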
@@ -180,7 +200,8 @@ class CommunicatorBenchmark:
                     lambda t, c=comm: c.custom_all_reduce(t),
                     lambda t, c=comm: c.should_custom_ar(t),
                     comm.capture(),
-                    "1stage",  # env variable value
+                    {"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"},
+                    None,  # no destroy function
                 )
             )
         # CustomAllreduce two-shot
@@ -190,7 +211,8 @@ class CommunicatorBenchmark:
                     lambda t, c=comm: c.custom_all_reduce(t),
                     lambda t, c=comm: c.should_custom_ar(t),
                     comm.capture(),
-                    "2stage",  # env variable value
+                    {"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"},
+                    None,  # no destroy function
                 )
             )
 
@@ -202,7 +224,8 @@ class CommunicatorBenchmark:
                     lambda t, c=comm: c.all_reduce(t),
                     lambda t: True,  # Always available if initialized
                     nullcontext(),
-                    None,  # no env variable needed
+                    {},  # no env variable needed
+                    None,  # no destroy function
                 )
             )
             communicators.append(
@@ -211,7 +234,8 @@ class CommunicatorBenchmark:
                     lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
                     lambda t: True,  # Always available if initialized
                     nullcontext(),
-                    None,  # no env variable needed
+                    {},  # no env variable needed
+                    None,  # no destroy function
                 )
             )
 
@@ -223,7 +247,8 @@ class CommunicatorBenchmark:
                     lambda t, c=comm: c.all_reduce(t),
                     lambda t, c=comm: c.should_use_symm_mem(t),
                     nullcontext(),
-                    None,  # no env variable needed
+                    {},  # no env variable needed
+                    None,  # no destroy function
                 )
             )
 
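The hunks above (and the start of the next one) all make the same mechanical change to the per-communicator entry: the single optional env-var string, which hard-wired VLLM_CUSTOM_ALLREDUCE_ALGO, becomes a dict of arbitrary environment overrides, and a sixth slot is added for an optional teardown callback. Spelled out as a type alias for clarity (the alias itself is illustrative, not part of the diff, which keeps bare tuples):

from collections.abc import Callable
from contextlib import AbstractContextManager

import torch

# (name, allreduce_fn, should_use_fn, capture_context, env_overrides, destroy_fn)
CommEntry = tuple[
    str,                               # display name, e.g. "flashinfer_mnnvl"
    Callable[[torch.Tensor], object],  # runs one allreduce on a tensor
    Callable[[torch.Tensor], bool],    # availability check for this tensor size
    AbstractContextManager,            # e.g. comm.capture() or nullcontext()
    dict[str, str],                    # env vars to set for the trial
    Callable[[], None] | None,         # optional cleanup after the trial
]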
@@ -235,29 +260,67 @@ class CommunicatorBenchmark:
                     lambda t, c=comm: c.all_reduce(t),
                     lambda t, c=comm: c.should_use_symm_mem(t),
                     nullcontext(),
-                    None,  # no env variable needed
+                    {},  # no env variable needed
+                    None,  # no destroy function needed
                 )
             )
 
+        if self.fi_ar_comm is not None:
+            comm = self.fi_ar_comm
+            communicators.append(
+                (
+                    "flashinfer_trtllm",
+                    lambda t, c=comm: c.all_reduce(t),
+                    lambda t, c=comm: c.should_use_fi_ar(t),
+                    nullcontext(),
+                    {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"},
+                    lambda c=comm: c.destroy(),
+                )
+            )
+            communicators.append(
+                (
+                    "flashinfer_mnnvl",
+                    lambda t, c=comm: c.all_reduce(t),
+                    lambda t, c=comm: c.should_use_fi_ar(t),
+                    nullcontext(),
+                    {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"},
+                    lambda c=comm: c.destroy(),
+                )
+            )
+
         # Benchmark each communicator
-        for name, allreduce_fn, should_use_fn, context, env_var in communicators:
-            # Set environment variable if needed
-            if env_var is not None:
-                os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
-            else:
-                # Clear the environment variable to avoid interference
-                os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
-
-            latency = self.benchmark_allreduce_single(
-                sequence_length,
-                allreduce_fn,
-                should_use_fn,
-                context,
-                num_warmup,
-                num_trials,
-            )
-            if latency is not None:
-                results[name] = latency
+        for (
+            name,
+            allreduce_fn,
+            should_use_fn,
+            context,
+            env_dict,
+            destroy_fn,
+        ) in communicators:
+            # Save original values and apply new environment variables
+            saved_env = {key: os.environ.get(key) for key in env_dict}
+            for key, value in env_dict.items():
+                os.environ[key] = value
+            try:
+                latency = self.benchmark_allreduce_single(
+                    sequence_length,
+                    allreduce_fn,
+                    should_use_fn,
+                    context,
+                    num_warmup,
+                    num_trials,
+                )
+                if latency is not None:
+                    results[name] = latency
+            finally:
+                if destroy_fn is not None:
+                    destroy_fn()
+                # Restore environment variables to their original state
+                for key, original_value in saved_env.items():
+                    if original_value is None:
+                        os.environ.pop(key, None)
+                    else:
+                        os.environ[key] = original_value
 
         return results
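The rewritten loop generalizes the old hard-coded VLLM_CUSTOM_ALLREDUCE_ALGO toggle: it snapshots the current value of every key in env_dict, applies the overrides, and restores the snapshot in a finally block alongside the optional destroy_fn, so a failed trial can no longer leak environment state or a live communicator into the next one. The same save/apply/restore discipline could be packaged as a context manager; a sketch under that assumption (scoped_env is not part of the diff):

import os
from contextlib import contextmanager


@contextmanager
def scoped_env(overrides):
    """Apply environment overrides for the block, then restore exactly."""
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, original in saved.items():
            if original is None:
                os.environ.pop(key, None)  # key did not exist before
            else:
                os.environ[key] = original


# Usage mirroring one benchmark trial (assuming the variable was unset
# before this snippet runs):
with scoped_env({"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"}):
    assert os.environ["VLLM_FLASHINFER_ALLREDUCE_BACKEND"] == "mnnvl"
assert "VLLM_FLASHINFER_ALLREDUCE_BACKEND" not in os.environ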