[Bugfix] adding chunking mechanism to fused_moe to handle large inputs (#6029)

This commit is contained in:
Avshalom Manevich
2024-07-02 00:08:29 +03:00
committed by GitHub
parent dec6fc6f3b
commit 12a59959ed
3 changed files with 74 additions and 48 deletions

View File

@@ -29,7 +29,7 @@ def torch_moe(a, w1, w2, score, topk):
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
-@pytest.mark.parametrize("m", [512, 222, 33, 1])
+@pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])