[1/N] torch.compile user interface design (#10237)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-11 18:01:06 -08:00
committed by GitHub
parent 9cdba9669c
commit eea55cca5b
4 changed files with 55 additions and 37 deletions

View File

@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.utils import direct_register_custom_op
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
global_counter = 0
# create a library to hold the custom op
@@ -48,7 +47,11 @@ direct_register_custom_op(
@support_torch_compile
class SillyModel(nn.Module):
def __init__(self) -> None:
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = '',
**kwargs) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
def test_simple_piecewise_compile():
model = SillyModel()
directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
model = SillyModel(vllm_config=VllmConfig(), prefix='')
inputs = torch.randn(100).cuda()