[torch.compile] rework compile control with piecewise cudagraph (#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-10-29 23:03:49 -07:00
committed by GitHub
parent 7b0365efef
commit ff5ed6e1bc
17 changed files with 979 additions and 102 deletions

View File

@@ -9,17 +9,19 @@ from vllm.platforms import current_platform
TEST_MODELS = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
# TODO: add fake implementation for compressed-tensors
# ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
# "dtype": torch.float16,
# "quantization": "compressed-tensors"
# }),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
"dtype": torch.float16,
"quantization": "fp8"
}),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
# TODO: add fake implementation for compressed-tensors
# ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
# "quantization": "compressed-tensors"
# }),
("meta-llama/Meta-Llama-3-8B", {}),
]
@@ -73,7 +75,7 @@ def check_full_graph_support(model,
# much memory.
quantization = model_kwargs.get("quantization")
if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
and optimization_level >= CompilationLevel.INDUCTOR):
and optimization_level >= CompilationLevel.PIECEWISE):
return
prompts = [