diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py index 7b453fe7b..d1005461a 100644 --- a/benchmarks/kernels/benchmark_device_communicators.py +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -30,6 +30,9 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce +from vllm.distributed.device_communicators.flashinfer_all_reduce import ( + FlashInferAllReduce, +) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator, register_nccl_symmetric_ops, @@ -44,7 +47,7 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser logger = init_logger(__name__) # Default sequence lengths to benchmark -DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192] +DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192] # Fixed hidden size and dtype for all benchmarks HIDDEN_SIZE = 8192 @@ -81,6 +84,7 @@ class CommunicatorBenchmark: self.symm_mem_comm = None self.symm_mem_comm_multimem = None self.symm_mem_comm_two_shot = None + self.fi_ar_comm = None self._init_communicators() @@ -161,6 +165,22 @@ class CommunicatorBenchmark: ) self.symm_mem_comm_two_shot = None + try: + self.fi_ar_comm = FlashInferAllReduce( + group=self.cpu_group, + device=self.device, + ) + if not self.fi_ar_comm.disabled: + logger.info("Rank %s: FlashInferAllReduce initialized", self.rank) + else: + logger.info("Rank %s: FlashInferAllReduce disabled", self.rank) + self.fi_ar_comm = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e + ) + self.fi_ar_comm = None + def benchmark_allreduce( self, sequence_length: int, num_warmup: int, num_trials: int ) -> dict[str, float]: @@ -180,7 +200,8 @@ class CommunicatorBenchmark: lambda t, c=comm: c.custom_all_reduce(t), lambda t, c=comm: c.should_custom_ar(t), comm.capture(), - 
"1stage", # env variable value + {"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"}, + None, # no destroy function ) ) # CustomAllreduce two-shot @@ -190,7 +211,8 @@ class CommunicatorBenchmark: lambda t, c=comm: c.custom_all_reduce(t), lambda t, c=comm: c.should_custom_ar(t), comm.capture(), - "2stage", # env variable value + {"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"}, + None, # no destroy function ) ) @@ -202,7 +224,8 @@ class CommunicatorBenchmark: lambda t, c=comm: c.all_reduce(t), lambda t: True, # Always available if initialized nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function ) ) communicators.append( @@ -211,7 +234,8 @@ class CommunicatorBenchmark: lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t), lambda t: True, # Always available if initialized nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function ) ) @@ -223,7 +247,8 @@ class CommunicatorBenchmark: lambda t, c=comm: c.all_reduce(t), lambda t, c=comm: c.should_use_symm_mem(t), nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function ) ) @@ -235,29 +260,67 @@ class CommunicatorBenchmark: lambda t, c=comm: c.all_reduce(t), lambda t, c=comm: c.should_use_symm_mem(t), nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function needed + ) + ) + + if self.fi_ar_comm is not None: + comm = self.fi_ar_comm + communicators.append( + ( + "flashinfer_trtllm", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_fi_ar(t), + nullcontext(), + {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"}, + lambda c=comm: c.destroy(), + ) + ) + communicators.append( + ( + "flashinfer_mnnvl", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_fi_ar(t), + nullcontext(), + {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"}, + lambda c=comm: c.destroy(), ) ) # Benchmark each communicator - for 
name, allreduce_fn, should_use_fn, context, env_var in communicators: - # Set environment variable if needed - if env_var is not None: - os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var - else: - # Clear the environment variable to avoid interference - os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None) - - latency = self.benchmark_allreduce_single( - sequence_length, - allreduce_fn, - should_use_fn, - context, - num_warmup, - num_trials, - ) - if latency is not None: - results[name] = latency + for ( + name, + allreduce_fn, + should_use_fn, + context, + env_dict, + destroy_fn, + ) in communicators: + # Save original values and apply new environment variables + saved_env = {key: os.environ.get(key) for key in env_dict} + for key, value in env_dict.items(): + os.environ[key] = value + try: + latency = self.benchmark_allreduce_single( + sequence_length, + allreduce_fn, + should_use_fn, + context, + num_warmup, + num_trials, + ) + if latency is not None: + results[name] = latency + finally: + if destroy_fn is not None: + destroy_fn() + # Restore environment variables to their original state + for key, original_value in saved_env.items(): + if original_value is None: + os.environ.pop(key, None) + else: + os.environ[key] = original_value return results diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 633529edf..e18f6a758 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -5,8 +5,11 @@ Benchmark for FlashInfer fused collective operations vs standard operations. This benchmark compares: -1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant) -2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations +1. FlashInfer's allreduce_fusion with trtllm backend + (fused allreduce + rmsnorm + optional FP8/FP4 quant) +2. 
FlashInfer's allreduce_fusion with mnnvl backend + (fused allreduce + rmsnorm only, no quantization support) +3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations Usage with torchrun: torchrun --nproc_per_node=2 benchmark_fused_collective.py @@ -48,8 +51,12 @@ SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant logger = init_logger(__name__) # Try to import FlashInfer +TorchDistBackend = None try: import flashinfer.comm as flashinfer_comm # type: ignore + from flashinfer.comm.mnnvl import ( # type: ignore + TorchDistBackend, + ) if not ( hasattr(flashinfer_comm, "allreduce_fusion") @@ -74,11 +81,15 @@ _FI_MAX_SIZES = { 8: 64 * MiB, # 64MB } -# Global workspace tensor for FlashInfer -_FI_WORKSPACE = None +# Global workspace tensors for FlashInfer (keyed by backend name) +_FI_WORKSPACES: dict = {} + +# Backends to benchmark +FLASHINFER_BACKENDS = ["trtllm", "mnnvl"] def setup_flashinfer_workspace( + backend: str, world_size: int, rank: int, hidden_dim: int, @@ -86,41 +97,54 @@ def setup_flashinfer_workspace( dtype: torch.dtype, ): """Setup FlashInfer workspace for fused allreduce operations.""" - global _FI_WORKSPACE + global FI_WORKSPACES if flashinfer_comm is None: - return None, None + return None if world_size not in _FI_MAX_SIZES: logger.warning("FlashInfer not supported for world size %s", world_size) - return None, None + return None try: + kwargs = {} + if TorchDistBackend is not None: + kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD) + workspace = flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", + backend=backend, world_size=world_size, rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, + **kwargs, ) - _FI_WORKSPACE = workspace + _FI_WORKSPACES[backend] = workspace return workspace except Exception as e: - logger.error("Failed to setup FlashInfer workspace: %s", e) + logger.error( + "Failed to setup FlashInfer workspace (backend=%s): %s", backend, e + ) return None 
-def cleanup_flashinfer_workspace(workspace): - """Cleanup FlashInfer workspace.""" - if flashinfer_comm is None or workspace is None: +def cleanup_flashinfer_workspaces(): + """Cleanup all FlashInfer workspaces.""" + if flashinfer_comm is None: return - try: - workspace.destroy() - except Exception as e: - logger.error("Failed to cleanup FlashInfer workspace: %s", e) + for backend, workspace in _FI_WORKSPACES.items(): + try: + workspace.destroy() + except Exception as e: + logger.error( + "Failed to cleanup FlashInfer workspace (backend=%s): %s", + backend, + e, + ) + _FI_WORKSPACES.clear() class FlashInferFusedAllReduceParams: @@ -134,7 +158,7 @@ class FlashInferFusedAllReduceParams: self.fp32_acc = True self.max_token_num = max_token_num - def get_trtllm_fused_allreduce_kwargs(self): + def get_flashinfer_fused_allreduce_kwargs(self): return { "launch_with_pdl": self.launch_with_pdl, "fp32_acc": self.fp32_acc, @@ -147,11 +171,12 @@ def flashinfer_fused_allreduce_rmsnorm( rms_gamma: torch.Tensor, rms_eps: float, allreduce_params: "FlashInferFusedAllReduceParams", + workspace: object, use_oneshot: bool, norm_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm operation.""" - if flashinfer_comm is None or _FI_WORKSPACE is None: + if flashinfer_comm is None or workspace is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -160,9 +185,13 @@ def flashinfer_fused_allreduce_rmsnorm( else: residual_out = input_tensor + layout_code = None + if workspace.backend == "trtllm": + layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4 + flashinfer_comm.allreduce_fusion( input=input_tensor, - workspace=_FI_WORKSPACE, + workspace=workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, residual_in=residual, residual_out=residual_out, @@ -171,10 +200,10 @@ def flashinfer_fused_allreduce_rmsnorm( rms_eps=rms_eps, quant_out=None, scale_out=None, - 
layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + layout_code=layout_code, scale_factor=None, use_oneshot=use_oneshot, - **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), ) @@ -185,12 +214,16 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( rms_eps: float, scale_factor: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, + workspace: object, use_oneshot: bool = True, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, ): - """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" - if flashinfer_comm is None or _FI_WORKSPACE is None: + """FlashInfer fused allreduce + rmsnorm + FP8 quantization. + + Note: Only supported by the trtllm backend. + """ + if flashinfer_comm is None or workspace is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -201,7 +234,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( flashinfer_comm.allreduce_fusion( input=input_tensor, - workspace=_FI_WORKSPACE, + workspace=workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, residual_in=residual, residual_out=residual_out, @@ -213,7 +246,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=scale_factor, use_oneshot=use_oneshot, - **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), ) @@ -224,13 +257,17 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( rms_eps: float, input_global_scale: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, + workspace: object, quant_out: torch.Tensor, use_oneshot: bool, output_scale: torch.Tensor, norm_out: torch.Tensor | None = None, ): - """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" - if flashinfer_comm is None or _FI_WORKSPACE is None: + """FlashInfer fused allreduce 
+ rmsnorm + FP4 quantization. + + Note: Only supported by the trtllm backend. + """ + if flashinfer_comm is None or workspace is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -241,7 +278,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( flashinfer_comm.allreduce_fusion( input=input_tensor, - workspace=_FI_WORKSPACE, + workspace=workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, residual_in=residual, residual_out=residual_out, @@ -253,7 +290,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=input_global_scale, use_oneshot=use_oneshot, - **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), ) @@ -386,13 +423,16 @@ def run_benchmarks( dtype: torch.dtype, use_residual: bool, allreduce_params: FlashInferFusedAllReduceParams | None, + workspaces: dict, quant_modes: set[str], no_oneshot: bool, ): """Run all benchmarks for given configuration. Args: - quant_mode: "none", "fp8_only", "fp4_only", or "all" + allreduce_params: Shared parameters for FlashInfer fused allreduce. + workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace. + quant_modes: Set of quantization modes: "none", "fp8", "fp4". 
""" ( input_tensor, @@ -454,10 +494,11 @@ def run_benchmarks( logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") - # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot - if flashinfer_comm is not None and allreduce_params is not None: + # FlashInfer Fused AllReduce + RMSNorm (all backends) + for backend, workspace in workspaces.items(): for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}" try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm, @@ -467,14 +508,17 @@ def run_benchmarks( rms_gamma=rms_gamma, rms_eps=rms_eps, allreduce_params=allreduce_params, + workspace=workspace, use_oneshot=use_oneshot, ) - results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms + results[key] = time_ms except Exception as e: - logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e) - results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float( - "inf" + logger.error( + "FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s", + backend, + e, ) + results[key] = float("inf") if "fp8" in quant_modes: # Standard AllReduce + RMSNorm + FP8 Quant @@ -540,10 +584,12 @@ def run_benchmarks( "inf" ) - # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot - if flashinfer_comm is not None and allreduce_params is not None: + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only) + if "trtllm" in workspaces: + trtllm_ws = workspaces["trtllm"] for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}" try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp8_quant, @@ -555,19 +601,16 @@ def run_benchmarks( scale_factor=scale_fp8, quant_out=quant_out_fp8, allreduce_params=allreduce_params, + workspace=trtllm_ws, use_oneshot=use_oneshot, ) - 
results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( - time_ms - ) + results[key] = time_ms except Exception as e: logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s", e, ) - results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( - float("inf") - ) + results[key] = float("inf") if "fp4" in quant_modes and current_platform.has_device_capability(100): # Standard AllReduce + RMSNorm + FP4 Quant @@ -627,10 +670,12 @@ def run_benchmarks( "inf" ) - # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot - if flashinfer_comm is not None and allreduce_params is not None: + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only) + if "trtllm" in workspaces: + trtllm_ws = workspaces["trtllm"] for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}" try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp4_quant, @@ -641,49 +686,18 @@ def run_benchmarks( rms_eps=rms_eps, input_global_scale=scale_fp4, allreduce_params=allreduce_params, + workspace=trtllm_ws, quant_out=fp4_quant_out, output_scale=fp4_output_scale, use_oneshot=use_oneshot, ) - results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( - time_ms - ) + results[key] = time_ms except Exception as e: logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s", e, ) - results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( - float("inf") - ) - - # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot - if flashinfer_comm is not None and allreduce_params is not None: - try: - time_ms = benchmark_operation( - flashinfer_fused_allreduce_rmsnorm_fp4_quant, - input_tensor, - residual=residual, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - 
input_global_scale=scale_fp4, - allreduce_params=allreduce_params, - quant_out=fp4_quant_out, - output_scale=fp4_output_scale, - use_oneshot=False, - ) - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( - time_ms - ) - except Exception as e: - logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", - e, - ) - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( - "inf" - ) + results[key] = float("inf") return results @@ -1021,8 +1035,7 @@ def main(): configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) - # Setup FlashInfer workspace if available - workspace = None + # Setup FlashInfer workspaces for all backends allreduce_params = None if flashinfer_comm is not None: @@ -1037,15 +1050,17 @@ def main(): args.hidden_dim * max_element_size ) - workspace = setup_flashinfer_workspace( - world_size, - rank, - args.hidden_dim, - max_num_token, - dtype=workspace_dtype, - ) + for backend in FLASHINFER_BACKENDS: + setup_flashinfer_workspace( + backend=backend, + world_size=world_size, + rank=rank, + hidden_dim=args.hidden_dim, + max_token_num=max_num_token, + dtype=workspace_dtype, + ) - if workspace is not None: + if _FI_WORKSPACES: allreduce_params = FlashInferFusedAllReduceParams( max_token_num=max_num_token, ) @@ -1071,6 +1086,7 @@ def main(): dtype, use_residual, allreduce_params, + workspaces=_FI_WORKSPACES, quant_modes=quant_modes, no_oneshot=args.no_oneshot, ) @@ -1109,11 +1125,13 @@ def main(): finally: # Cleanup - if workspace is not None: - cleanup_flashinfer_workspace(workspace) + cleanup_flashinfer_workspaces() dist.barrier() if __name__ == "__main__": - main() + from vllm.config import VllmConfig, set_current_vllm_config + + with set_current_vllm_config(VllmConfig()): + main() diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py index d48f22970..6d5113b1e 100644 --- 
a/tests/compile/passes/distributed/test_fusion_all_reduce.py +++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py @@ -142,7 +142,6 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): *(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale)) ) self.wq, self.wscale = list(wq_gen), list(wscale_gen) - print(f"{self.wq=}, {self.wscale=}") def forward(self, hidden_states): # avoid having graph input be an arg to a pattern directly @@ -199,6 +198,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module): @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) +@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA") @pytest.mark.skipif( not find_spec("flashinfer") @@ -215,6 +215,7 @@ def test_all_reduce_fusion_pass_replace( dtype: torch.dtype, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, + flashinfer_allreduce_backend, ): num_processes = 2 if ( @@ -238,6 +239,7 @@ def test_all_reduce_fusion_pass_replace( dtype, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, + flashinfer_allreduce_backend, ), nprocs=nprocs, ) @@ -255,6 +257,7 @@ def all_reduce_fusion_pass_on_test_model( dtype: torch.dtype, enable_rms_norm_custom_op, enable_quant_fp8_custom_op, + flashinfer_allreduce_backend, ): set_random_seed(0) @@ -270,6 +273,7 @@ def all_reduce_fusion_pass_on_test_model( "WORLD_SIZE": str(world_size), "MASTER_ADDR": "localhost", "MASTER_PORT": "12345", + "VLLM_FLASHINFER_ALLREDUCE_BACKEND": flashinfer_allreduce_backend, } ) @@ -317,6 +321,10 @@ def all_reduce_fusion_pass_on_test_model( compiled_model = torch.compile(model, backend=backend) compiled_model(hidden_states) + results_unfused = model(hidden_states) + results_fused = compiled_model(hidden_states) + 
torch.testing.assert_close(results_unfused, results_fused, atol=1e-2, rtol=1e-2) + assert all_reduce_fusion_pass.matched_count == 4, ( f"{all_reduce_fusion_pass.matched_count=}" ) diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index b6a1314af..44dc3d67b 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -22,7 +22,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) from vllm.platforms import current_platform -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch_utils import ( + direct_register_custom_op, +) from ..inductor_pass import enable_fake_mode from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass @@ -44,8 +46,6 @@ if find_spec("flashinfer"): except ImportError: pass -logger = init_logger(__name__) - if hasattr(torch.ops._C, "scaled_fp4_quant"): STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default @@ -82,7 +82,16 @@ _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = { if flashinfer_comm is not None: - _FI_WORKSPACE = None + from vllm.distributed.device_communicators.flashinfer_all_reduce import ( + destroy_fi_ar_workspace, + get_fi_ar_quant_workspace, + get_fi_ar_workspace, + initialize_fi_ar_quant_workspace, + initialize_fi_ar_workspace, + ) + + ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern + MiB = 1024 * 1024 def call_trtllm_fused_allreduce_norm( @@ -122,9 +131,19 @@ if flashinfer_comm is not None: max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB ) - assert _FI_WORKSPACE is not None, ( - "Flashinfer must be enabled when using flashinfer" + # Select workspace based on pattern: quant patterns use the + # trtllm quant workspace, non-quant patterns use the primary workspace. 
+ if pattern_code in ( + ar_fusion_patterns.kARResidualRMSNormFP8Quant, + ar_fusion_patterns.kARResidualRMSNormFP4Quant, + ): + workspace = get_fi_ar_quant_workspace() + else: + workspace = get_fi_ar_workspace() + assert workspace is not None, ( + "Flashinfer workspace must be initialized when using flashinfer" ) + assert flashinfer_comm is not None if norm_out is None: norm_out = allreduce_in residual_out = residual @@ -133,25 +152,30 @@ if flashinfer_comm is not None: # as flashinfer does not support rms_norm # and allreduce_out together residual_out = allreduce_in - # For the sizes that are smaller than the max size, - # we only use flashinfer one shot allreduce + + layout_code = None + # layout_code only supported by trtllm backend + if workspace.backend == "trtllm": + # in vllm we only support swizzled layout + layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4 + flashinfer_comm.allreduce_fusion( input=allreduce_in, - workspace=_FI_WORKSPACE, + workspace=workspace, pattern=pattern_code, - residual_in=residual, + launch_with_pdl=launch_with_pdl, + output=None, residual_out=residual_out, norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - launch_with_pdl=launch_with_pdl, - use_oneshot=use_oneshot, - fp32_acc=fp32_acc, quant_out=quant_out, scale_out=scale_out, - # in vllm we only support swizzled layout - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + residual_in=residual, + rms_gamma=rms_gamma, + rms_eps=rms_eps, scale_factor=scale_factor, + layout_code=layout_code, + use_oneshot=use_oneshot, + fp32_acc=fp32_acc, ) def call_trtllm_fused_allreduce_norm_fake( @@ -729,29 +753,36 @@ class AllReduceFusionPass(VllmPatternMatcherPass): scope="global", ) - try: - self.workspace = flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", - world_size=self.tp_size, - rank=rank, - max_token_num=self.max_token_num, - hidden_dim=self.hidden_dim, - dtype=self.model_dtype, - ) - except RuntimeError as e: - if 
"multicast" not in str(e).lower(): - raise - logger.warning_once( - "AllReduce fusion pass is disabled: flashinfer workspace " - "creation failed: %s. This is expected on GPUs without " - "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). " - "Falling back to non-fused allreduce.", - str(e), - ) - return + for workspace_init_fn in [ + initialize_fi_ar_workspace, + initialize_fi_ar_quant_workspace, + ]: + try: + workspace_init_fn( + world_size=self.tp_size, + rank=rank, + max_token_num=self.max_token_num, + hidden_dim=self.hidden_dim, + dtype=self.model_dtype, + group=self.group, + ) + except Exception as e: + if "multicast" in str(e).lower(): + logger.warning( + "AllReduce fusion pass is disabled: flashinfer workspace " + "creation failed: %s. This is expected on GPUs without " + "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). " + "Falling back to non-fused allreduce.", + str(e), + ) + else: + logger.warning( + "Failed to initialize FlashInfer All Reduce workspace: %s. " + "AllReduce fusion pass will be disabled.", + e, + ) + return - global _FI_WORKSPACE - _FI_WORKSPACE = self.workspace self.allreduce_params = FlashInferFusedAllReduceParams( world_size=self.tp_size, max_token_num=self.max_token_num, @@ -762,32 +793,34 @@ class AllReduceFusionPass(VllmPatternMatcherPass): @enable_fake_mode def register_patterns(self) -> None: + supports_quantization = get_fi_ar_quant_workspace() is not None for epsilon in [1e-5, 1e-6]: - AllReduceFusedRMSNormStaticQuantFP8Pattern( - epsilon, - self.model_dtype, - self.device, - self.allreduce_params, - ).register(self.patterns) - AllReduceFusedAddRMSNormStaticQuantFP8Pattern( - epsilon, - self.model_dtype, - self.device, - self.allreduce_params, - ).register(self.patterns) - if current_platform.has_device_capability(100): - AllReduceFusedRMSNormStaticQuantNVFP4Pattern( + if supports_quantization: + AllReduceFusedRMSNormStaticQuantFP8Pattern( epsilon, self.model_dtype, self.device, self.allreduce_params, 
).register(self.patterns) - AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern( + AllReduceFusedAddRMSNormStaticQuantFP8Pattern( epsilon, self.model_dtype, self.device, self.allreduce_params, ).register(self.patterns) + if current_platform.has_device_capability(100): + AllReduceFusedRMSNormStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) AllReduceRMSNormPattern( epsilon, self.model_dtype, @@ -825,6 +858,5 @@ class AllReduceFusionPass(VllmPatternMatcherPass): def __del__(self) -> None: if getattr(self, "disabled", True): return - if getattr(self, "workspace", None) is not None: - with contextlib.suppress(Exception): - self.workspace.destroy() + with contextlib.suppress(Exception): + destroy_fi_ar_workspace() diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 4c78871e1..62e2b9037 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -34,19 +34,25 @@ class CudaCommunicator(DeviceCommunicatorBase): # custom allreduce or torch symm mem can be used only by tp use_custom_allreduce = False use_torch_symm_mem = False + use_flashinfer_allreduce = False else: from vllm.distributed.parallel_state import _ENABLE_CUSTOM_ALL_REDUCE use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM + use_flashinfer_allreduce = envs.VLLM_ALLREDUCE_USE_FLASHINFER self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem = use_torch_symm_mem + self.use_flashinfer_allreduce = use_flashinfer_allreduce # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce, ) + from 
vllm.distributed.device_communicators.flashinfer_all_reduce import ( + FlashInferAllReduce, + ) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce, @@ -65,12 +71,20 @@ class CudaCommunicator(DeviceCommunicatorBase): self.ca_comm: CustomAllreduce | None = None self.qr_comm: QuickAllReduce | None = None self.symm_mem_comm: SymmMemCommunicator | None = None + self.fi_ar_comm: FlashInferAllReduce | None = None + if use_torch_symm_mem and current_platform.is_cuda(): self.symm_mem_comm = SymmMemCommunicator( group=self.cpu_group, device=self.device, ) + if self.use_flashinfer_allreduce and self.world_size > 1: + self.fi_ar_comm = FlashInferAllReduce( + group=self.cpu_group, + device=self.device, + ) + if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -136,7 +150,7 @@ class CudaCommunicator(DeviceCommunicatorBase): out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_) if out is not None: return out - # always try quick reduce first, then custom allreduce, + # always try quick reduce first, then flashinfer, then custom allreduce, # and then pynccl. 
(quick reduce just for ROCM MI3*) qr_comm = self.qr_comm if ( @@ -147,6 +161,15 @@ class CudaCommunicator(DeviceCommunicatorBase): out = qr_comm.quick_all_reduce(input_) assert out is not None return out + fi_ar_comm = self.fi_ar_comm + if ( + fi_ar_comm is not None + and not fi_ar_comm.disabled + and fi_ar_comm.should_use_fi_ar(input_) + ): + out = fi_ar_comm.all_reduce(input_) + assert out is not None + return out ca_comm = self.ca_comm if ( ca_comm is not None @@ -270,6 +293,9 @@ class CudaCommunicator(DeviceCommunicatorBase): self.pynccl_comm = None if self.ca_comm is not None: self.ca_comm = None + if self.fi_ar_comm is not None: + self.fi_ar_comm.destroy() + self.fi_ar_comm = None if self.all2all_manager is not None: self.all2all_manager.destroy() self.all2all_manager = None diff --git a/vllm/distributed/device_communicators/flashinfer_all_reduce.py b/vllm/distributed/device_communicators/flashinfer_all_reduce.py new file mode 100644 index 000000000..ea16c9376 --- /dev/null +++ b/vllm/distributed/device_communicators/flashinfer_all_reduce.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm.config.compilation import PassConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +fi_ar_available = False +try: + import flashinfer.comm as flashinfer_comm # type: ignore[no-redef] + from flashinfer.comm.mnnvl import ( + TorchDistBackend, # type: ignore[import-not-found, no-redef] + ) + + fi_ar_available = hasattr(flashinfer_comm, "allreduce_fusion") +except ImportError: + pass + +# Global workspace for standalone allreduce and non-quant ar+rms fusion +_fi_ar_workspace = None +# Extra workspace for quant fusion patterns (only supported by trtllm backend) +# Only created if primary workspace 
is not already trtllm +_fi_ar_quant_workspace = None + + +def get_fi_ar_workspace(): + return _fi_ar_workspace + + +def get_fi_ar_quant_workspace(): + return _fi_ar_quant_workspace + + +def initialize_fi_ar_workspace( + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + group: ProcessGroup, +) -> None: + """ + Initialize the workspace if not already initialized. + + Currently, this function is called by either the AllReduceFusionPass + or the FlashInferAllReduce backend for standalone allreduce. + If the fusion pass is enabled via + --compilation-config.pass_config.fuse_allreduce_rms=true, + it will create the workspace first, and the standalone backend + will reuse the workspace. Otherwise, the standalone backend will + create the workspace. + """ + global _fi_ar_workspace + if _fi_ar_workspace is not None: + return + + backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND + comm_backend = TorchDistBackend(group=group) + _fi_ar_workspace = flashinfer_comm.create_allreduce_fusion_workspace( + backend=backend, + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + comm_backend=comm_backend, + ) + assert _fi_ar_workspace is not None + logger.debug( + "Initialized FlashInfer All Reduce workspace: backend=%s, " + "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s", + backend, + world_size, + rank, + max_token_num, + hidden_dim, + dtype, + ) + + +def initialize_fi_ar_quant_workspace( + world_size: int, + rank: int, + max_token_num: int, + hidden_dim: int, + dtype: torch.dtype, + group: ProcessGroup, +) -> None: + """ + Initialize the workspace used by quantization fusion patterns. + + Currently this always creates a workspace for trtllm backend as only it + supports quantization fusion (FP8/FP4). If the primary workspace + is already trtllm, the quant workspace aliases to it. 
+ """ + global _fi_ar_quant_workspace + if _fi_ar_quant_workspace is not None: + return + + # If primary workspace is already trtllm, reuse it + if _fi_ar_workspace is not None and _fi_ar_workspace.backend == "trtllm": + _fi_ar_quant_workspace = _fi_ar_workspace + return + + comm_backend = TorchDistBackend(group=group) + _fi_ar_quant_workspace = flashinfer_comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=world_size, + rank=rank, + max_token_num=max_token_num, + hidden_dim=hidden_dim, + dtype=dtype, + comm_backend=comm_backend, + ) + assert _fi_ar_quant_workspace is not None + logger.debug( + "Initialized FlashInfer All Reduce workspace: backend=trtllm, " + "world_size=%d, rank=%d, max_token_num=%d, hidden_dim=%d, dtype=%s", + world_size, + rank, + max_token_num, + hidden_dim, + dtype, + ) + + +def destroy_fi_ar_workspace(): + global _fi_ar_workspace + global _fi_ar_quant_workspace + if ( + _fi_ar_quant_workspace is not None + and _fi_ar_quant_workspace is not _fi_ar_workspace + ): + _fi_ar_quant_workspace.destroy() + _fi_ar_quant_workspace = None + if _fi_ar_workspace is not None: + _fi_ar_workspace.destroy() + _fi_ar_workspace = None + + +class FlashInferAllReduce: + def __init__( + self, + group: ProcessGroup, + device: int | str | torch.device, + ): + self.disabled = True + + if not fi_ar_available: + logger.info( + "FlashInfer All Reduce is disabled because flashinfer is not available" + ) + return + + if not current_platform.is_cuda(): + logger.info( + "FlashInfer All Reduce is disabled because it requires CUDA platform" + ) + return + + self.group = group + self.world_size = dist.get_world_size(self.group) + self.rank = dist.get_rank(self.group) + self.device = device + if self.world_size == 1: + return + + # Use the same threshold as the allreduce-rms fusion pass + # TODO: tune the threshold + MiB = 1024 * 1024 + max_workspace_size = PassConfig.default_fi_allreduce_fusion_max_size_mb().get( + self.world_size, None + ) + if not 
max_workspace_size: + logger.warning( + "FlashInfer All Reduce is disabled because it " + "is not supported for world_size=%d.", + self.world_size, + ) + return + self.max_workspace_size = max_workspace_size * MiB + self.max_num_tokens = 0 + self.disabled = False + + def _ensure_workspace(self, hidden_dim: int, dtype: torch.dtype) -> bool: + """Ensure the all reduce workspace is initialized.""" + if get_fi_ar_workspace() is not None: + return True + if self.max_num_tokens == 0: + element_size = torch.tensor([], dtype=dtype, device="cpu").element_size() + self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size) + try: + initialize_fi_ar_workspace( + world_size=self.world_size, + rank=self.rank, + max_token_num=self.max_num_tokens, + hidden_dim=hidden_dim, + dtype=dtype, + group=self.group, + ) + return True + except Exception as e: + logger.warning( + "Failed to initialize FlashInfer All Reduce workspace: %s. " + "FlashInfer All Reduce will be disabled.", + e, + ) + self.disabled = True + return False + + def should_use_fi_ar(self, input_tensor: torch.Tensor) -> bool: + if self.disabled: + return False + + if not input_tensor.is_cuda: + return False + + if not input_tensor.is_contiguous(): + return False + + if len(input_tensor.shape) != 2: + return False + + num_tokens, hidden_dim = input_tensor.shape + if not self.max_num_tokens: + element_size = torch.tensor([], dtype=input_tensor.dtype).element_size() + self.max_num_tokens = self.max_workspace_size // (hidden_dim * element_size) + + if num_tokens > self.max_num_tokens: + return False + + return self._ensure_workspace(hidden_dim, input_tensor.dtype) + + def all_reduce(self, input_tensor: torch.Tensor) -> torch.Tensor: + workspace = get_fi_ar_workspace() + return flashinfer_comm.allreduce_fusion( + input=input_tensor, + workspace=workspace, + pattern=flashinfer_comm.AllReduceFusionPattern.kAllReduce, + ) + + def destroy(self): + if not self.disabled: + destroy_fi_ar_workspace() diff --git 
a/vllm/envs.py b/vllm/envs.py index d62438d57..d560cfc77 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -168,6 +168,7 @@ if TYPE_CHECKING: VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) + VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto" VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 @@ -206,6 +207,7 @@ if TYPE_CHECKING: VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True + VLLM_ALLREDUCE_USE_FLASHINFER: bool = False VLLM_TUNED_CONFIG_FOLDER: str | None = None VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set() VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False @@ -1290,6 +1292,14 @@ environment_variables: dict[str, Callable[[], Any]] = { "latency", ["throughput", "latency", "masked_gemm"], ), + # Flashinfer fused allreduce backend. + # "auto" will default to "mnnvl", which generally performs the same as or better than "trtllm". + # But the "mnnvl" backend does not support fusing with quantization. + "VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices( + "VLLM_FLASHINFER_ALLREDUCE_BACKEND", + "auto", + ["auto", "trtllm", "mnnvl"], + ), # Control the workspace buffer size for the FlashInfer backend.
"VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int( os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024)) @@ -1448,6 +1458,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool( int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")) ), + # Whether to use FlashInfer allreduce + "VLLM_ALLREDUCE_USE_FLASHINFER": lambda: bool( + int(os.getenv("VLLM_ALLREDUCE_USE_FLASHINFER", "0")) + ), # Experimental: use this to enable MCP tool calling for non harmony models "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool( int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0"))