[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)
Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
This commit is contained in:
@@ -142,7 +142,6 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
|
||||
)
|
||||
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
|
||||
print(f"{self.wq=}, {self.wscale=}")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
@@ -199,6 +198,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("flashinfer_allreduce_backend", ["trtllm", "mnnvl"])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(
|
||||
not find_spec("flashinfer")
|
||||
@@ -215,6 +215,7 @@ def test_all_reduce_fusion_pass_replace(
|
||||
dtype: torch.dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
flashinfer_allreduce_backend,
|
||||
):
|
||||
num_processes = 2
|
||||
if (
|
||||
@@ -238,6 +239,7 @@ def test_all_reduce_fusion_pass_replace(
|
||||
dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
flashinfer_allreduce_backend,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
@@ -255,6 +257,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
dtype: torch.dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
flashinfer_allreduce_backend,
|
||||
):
|
||||
set_random_seed(0)
|
||||
|
||||
@@ -270,6 +273,7 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": flashinfer_allreduce_backend,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -317,6 +321,10 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
results_unfused = model(hidden_states)
|
||||
results_fused = compiled_model(hidden_states)
|
||||
torch.testing.assert_close(results_unfused, results_fused, atol=1e-2, rtol=1e-2)
|
||||
|
||||
assert all_reduce_fusion_pass.matched_count == 4, (
|
||||
f"{all_reduce_fusion_pass.matched_count=}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user