[torch.compile] integration with compilation control (#9058)
This commit is contained in:
@@ -4,16 +4,9 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.plugins import set_torch_compile_backend
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.utils import is_hip
|
||||
|
||||
TEST_MODELS_SMOKE = [
|
||||
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
|
||||
"quantization": "compressed-tensors"
|
||||
}),
|
||||
("meta-llama/Meta-Llama-3-8B", {}),
|
||||
]
|
||||
|
||||
TEST_MODELS = [
|
||||
("facebook/opt-125m", {}),
|
||||
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
|
||||
@@ -68,20 +61,21 @@ if not is_hip() and is_quant_method_supported("awq"):
|
||||
}))
|
||||
|
||||
|
||||
def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
|
||||
def check_full_graph_support(model,
|
||||
model_kwargs,
|
||||
optimization_level,
|
||||
tp_size=1):
|
||||
# make sure these models can be captured in full graph mode
|
||||
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
|
||||
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
|
||||
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
|
||||
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
|
||||
|
||||
# Inductor doesn't support fp8/gptq_marlin_24 yet.
|
||||
quantization = model_kwargs.get("quantization")
|
||||
if (quantization == "fp8" or quantization == "gptq_marlin"
|
||||
or quantization == "gptq_marlin_24") and backend != "eager":
|
||||
or quantization == "gptq_marlin_24"
|
||||
) and optimization_level >= CompilationLevel.INDUCTOR:
|
||||
return
|
||||
|
||||
set_torch_compile_backend(backend)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
|
||||
Reference in New Issue
Block a user