[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)

Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
This commit is contained in:
Hanjie Qiu
2026-02-25 19:17:20 -08:00
committed by GitHub
parent 2aa4140402
commit 71dfce6aa6
7 changed files with 593 additions and 180 deletions

View File

@@ -142,7 +142,6 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
@@ -199,6 +198,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
@@ -215,6 +215,7 @@ def test_all_reduce_fusion_pass_replace(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
num_processes = 2
if (
@@ -238,6 +239,7 @@ def test_all_reduce_fusion_pass_replace(
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
),
nprocs=nprocs,
)
@@ -255,6 +257,7 @@ def all_reduce_fusion_pass_on_test_model(
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
flashinfer_allreduce_backend,
):
set_random_seed(0)
@@ -270,6 +273,7 @@ def all_reduce_fusion_pass_on_test_model(
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": flashinfer_allreduce_backend,
}
)
@@ -317,6 +321,10 @@ def all_reduce_fusion_pass_on_test_model(
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
results_unfused = model(hidden_states)
results_fused = compiled_model(hidden_states)
torch.testing.assert_close(results_unfused, results_fused, atol=1e-2, rtol=1e-2)
assert all_reduce_fusion_pass.matched_count == 4, (
f"{all_reduce_fusion_pass.matched_count=}"
)