[6/N] torch.compile rollout to users (#10437)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
Test the piecewise compilation with a simple model so that we
|
||||
can exactly calculate the expected output and side effects.
|
||||
"""
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -11,7 +10,7 @@ from torch.library import Library
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
from vllm.plugins import set_current_vllm_config
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@@ -77,12 +76,12 @@ class SillyModel(nn.Module):
|
||||
|
||||
def test_simple_piecewise_compile():
|
||||
|
||||
directory = os.path.dirname(__file__)
|
||||
config = os.path.join(directory, "piecewise_compilation_config.json")
|
||||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["silly.attention"],
|
||||
cudagraph_copy_inputs=True,
|
||||
))
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SillyModel(vllm_config=vllm_config, prefix='')
|
||||
|
||||
@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
|
||||
output = model(input)
|
||||
assert global_counter == 2
|
||||
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
|
||||
|
||||
# clean up to avoid side effects for other tests
|
||||
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]
|
||||
|
||||
Reference in New Issue
Block a user