Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Signed-off-by: chzhang <chaojun.zhang@intel.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
70 lines
2.3 KiB
Python
70 lines
2.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import pytest
|
|
import torch
|
|
from torch import nn
|
|
|
|
import vllm.kernels # noqa: F401 to register kernels
|
|
from vllm import ir
|
|
from vllm.compilation.passes.ir.lowering_pass import (
|
|
VllmIRLoweringPass,
|
|
)
|
|
from vllm.config import get_current_vllm_config
|
|
from vllm.ir import ops
|
|
from vllm.platforms import current_platform
|
|
|
|
from ...backend import TestBackend
|
|
|
|
|
|
class Model(nn.Module):
    """Toy module with three rms_norm call sites, interleaved with
    elementwise ops, used to exercise lowering of each site separately."""

    def __init__(self, hidden_size=16, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_size = hidden_size
        # bf16 all-ones weight keeps the normalized output easy to reason about
        self.weight = torch.ones(hidden_size, dtype=torch.bfloat16)

    def forward(self, x):
        out = ops.rms_norm(x + 4.0, self.weight, 1e-5)
        # no weight
        out = ops.rms_norm(out * 5.0, None, 1e-5)
        # dispatch to native due to variance_size parameter
        out = ops.rms_norm(out / 2.0, self.weight, 1e-5, self.hidden_size // 2)
        return out + 3.0
|
|
|
|
|
|
@pytest.mark.parametrize("rms_provider", ops.rms_norm.supported_providers())
def test_lowering_rms_norm(rms_provider, default_vllm_config):
    """Compile Model with and without the IR lowering pass, then check that
    every rms_norm call site was dispatched to the expected provider and that
    the lowered numerics match the unlowered reference."""
    torch.set_default_device(current_platform.device_type)

    pass_ = VllmIRLoweringPass(get_current_vllm_config())
    lowered_backend = TestBackend(pass_)
    reference_backend = TestBackend()

    model = Model()
    inp = torch.randn(8, 16, dtype=torch.bfloat16)

    with (
        ops.rms_norm.set_priority([rms_provider, "native"]),
        ir.enable_torch_wrap(True),
    ):
        lowered = torch.compile(model, backend=lowered_backend, fullgraph=True)
        reference = torch.compile(
            model, backend=reference_backend, fullgraph=True
        )
        lowered_out = lowered(inp)
        reference_out = reference(inp)

    # One entry per rms_norm call site in Model.forward.
    selected = pass_.selected_impls["rms_norm"]
    assert len(selected) == 3
    assert selected["rms_norm"] == rms_provider
    assert selected["rms_norm_1"] == rms_provider
    # The variance_size call site must fall back to the native impl.
    assert selected["rms_norm_2"] == "native"

    # Compiled function guards on global value, avoid recompilation
    with ir.enable_torch_wrap(True):
        lowered_out_again = lowered(inp)

    torch.testing.assert_close(reference_out, lowered_out)
    torch.testing.assert_close(reference_out, lowered_out_again)
|