[1/N] torch.compile user interface design (#10237)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
|
||||
global_counter = 0
|
||||
|
||||
# create a library to hold the custom op
|
||||
@@ -48,7 +47,11 @@ direct_register_custom_op(
|
||||
@support_torch_compile
|
||||
class SillyModel(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
|
||||
|
||||
def test_simple_piecewise_compile():
|
||||
|
||||
model = SillyModel()
|
||||
|
||||
directory = os.path.dirname(__file__)
|
||||
config = os.path.join(directory, "piecewise_compilation_config.json")
|
||||
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
|
||||
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
|
||||
|
||||
model = SillyModel(vllm_config=VllmConfig(), prefix='')
|
||||
|
||||
inputs = torch.randn(100).cuda()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user