[torch.compile] support moe models (#9632)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-10-27 21:58:04 -07:00
committed by GitHub
parent 4e2d95e372
commit 32176fee73
12 changed files with 216 additions and 77 deletions

View File

@@ -13,11 +13,11 @@ from ..utils import compare_all_settings
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])