[6/N] torch.compile rollout to users (#10437)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
-from vllm.config import CompilationLevel
|
||||
+from vllm.config import CompilationConfig, CompilationLevel
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
TEST_MODELS = [
|
||||
@@ -65,7 +65,6 @@ def check_full_graph_support(model,
|
||||
optimization_level,
|
||||
tp_size=1):
|
||||
# make sure these models can be captured in full graph mode
|
||||
-os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
|
||||
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
|
||||
|
||||
# The base meta llama uses too much memory.
|
||||
@@ -86,6 +85,7 @@ def check_full_graph_support(model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=tp_size,
|
||||
disable_custom_all_reduce=True,
|
||||
+compilation_config=CompilationConfig(level=optimization_level),
|
||||
**model_kwargs)
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
Reference in New Issue
Block a user