[6/N] torch.compile rollout to users (#10437)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-19 10:09:03 -08:00
committed by GitHub
parent fd9f124971
commit 803f37eaaa
15 changed files with 129 additions and 141 deletions

View File

@@ -2,7 +2,6 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import os
import torch
from torch import nn
@@ -11,7 +10,7 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op
@@ -77,12 +76,12 @@ class SillyModel(nn.Module):
def test_simple_piecewise_compile():
directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
# clean up to avoid side effects for other tests
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]