[Kernel] FlashInfer: switch allreduce fusion to unified API (#33985)
Signed-off-by: Mohammad Miadh Angkad <176301910+mmangkad@users.noreply.github.com>
parent cb62e86f83
commit d4f123cc48
@@ -202,9 +202,10 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 @pytest.mark.skipif(
     not find_spec("flashinfer")
-    or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
+    or not has_module_attribute("flashinfer.comm", "allreduce_fusion")
+    or not has_module_attribute("flashinfer.comm", "create_allreduce_fusion_workspace"),
     reason="flashinfer is not found or flashinfer "
-    "is not compiled with trtllm_allreduce_fusion",
+    "is not compiled with allreduce_fusion",
 )
 def test_all_reduce_fusion_pass_replace(
     test_model: torch.nn.Module,
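For context, a minimal sketch of what a helper with the signature used above (has_module_attribute(module_name, attribute_name)) could look like; this is an assumption for illustration, not vLLM's actual implementation. It returns True only when the module is importable and exposes the named attribute, which is exactly how the skip condition uses it.

# Hypothetical sketch of the has_module_attribute helper used in the skipif
# condition above (assumed behavior, not the real vLLM utility).
from importlib import import_module
from importlib.util import find_spec


def has_module_attribute(module_name: str, attribute_name: str) -> bool:
    # Avoid importing anything if the top-level package is not installed.
    if find_spec(module_name.split(".")[0]) is None:
        return False
    try:
        module = import_module(module_name)
    except ImportError:
        # The package exists, but this submodule (e.g. flashinfer.comm)
        # may lack optional compiled components.
        return False
    return hasattr(module, attribute_name)

With a guard like this, the test only runs when flashinfer.comm exposes both allreduce_fusion and create_allreduce_fusion_workspace, i.e. when the installed FlashInfer build provides the unified allreduce-fusion API that this commit switches to.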