[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Signed-off-by: chzhang <chaojun.zhang@intel.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
This commit is contained in:
69
tests/compile/passes/ir/test_lowering.py
Normal file
69
tests/compile/passes/ir/test_lowering.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import vllm.kernels # noqa: F401 to register kernels
|
||||
from vllm import ir
|
||||
from vllm.compilation.passes.ir.lowering_pass import (
|
||||
VllmIRLoweringPass,
|
||||
)
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.ir import ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...backend import TestBackend
|
||||
|
||||
|
||||
class Model(nn.Module):
    """Small module with three rms_norm call sites separated by elementwise ops.

    The three calls differ deliberately: one with a weight, one without, and
    one passing a variance_size argument (which, per the lowering test below,
    is expected to dispatch to the native implementation).
    """

    def __init__(self, hidden_size=16, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size
        # bf16 all-ones weight matching the hidden dimension.
        self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)

    def forward(self, x):
        shifted = x + 4.0
        normed = ops.rms_norm(shifted, self.weight, 1e-5)
        scaled = normed * 5.0
        # Weight-less variant of rms_norm.
        normed_again = ops.rms_norm(scaled, None, 1e-5)
        halved = normed_again / 2.0
        # dispatch to native due to variance_size parameter
        final = ops.rms_norm(halved, self.weight, 1e-5, self.hidden_size // 2)
        return final + 3.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
def test_lowering_rms_norm(rms_provider, default_vllm_config):
    """Check that the IR lowering pass picks the expected provider per
    rms_norm call site and that the lowered graph matches the unlowered one
    numerically."""
    torch.set_default_device(current_platform.device_type)

    pass_under_test = VllmIRLoweringPass(get_current_vllm_config())
    lowered_backend = TestBackend(pass_under_test)
    plain_backend = TestBackend()

    model = Model()
    sample = torch.randn(8, 16, dtype=torch.bfloat16)
    with (
        ops.rms_norm.set_priority([rms_provider, "native"]),
        ir.enable_torch_wrap(True),
    ):
        lowered_fn = torch.compile(model, backend=lowered_backend, fullgraph=True)
        plain_fn = torch.compile(
            model, backend=plain_backend, fullgraph=True
        )
        lowered_out = lowered_fn(sample)
        plain_out = plain_fn(sample)

    chosen = pass_under_test.selected_impls["rms_norm"]
    assert len(chosen) == 3
    assert chosen["rms_norm"] == rms_provider
    assert chosen["rms_norm_1"] == rms_provider
    # The call site with variance_size must fall back to the native impl.
    assert chosen["rms_norm_2"] == "native"

    # Compiled function guards on global value, avoid recompilation
    with ir.enable_torch_wrap(True):
        lowered_out2 = lowered_fn(sample)

    torch.testing.assert_close(plain_out, lowered_out)
    torch.testing.assert_close(plain_out, lowered_out2)
|
||||
Reference in New Issue
Block a user