[vLLM IR] 1/N Implement IR skeleton and rms_norm op (#33825)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com> Signed-off-by: chzhang <chaojun.zhang@intel.com> Signed-off-by: Luka Govedic <luka.govedic@gmail.com> Co-authored-by: Xinyu Chen <xinyu1.chen@intel.com> Co-authored-by: Chaojun Zhang <chaojun.zhang@intel.com> Co-authored-by: Luka Govedič <ProExpertProg@h100-01.nemg-001.lab.rdu2.dc.redhat.com>
This commit is contained in:
497
tests/ir/test_op.py
Normal file
497
tests/ir/test_op.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.util
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
import vllm.ir.op
|
||||
from vllm.ir.op import RESERVED_PROVIDERS, IrOp, IrOpImpl
|
||||
|
||||
# Sanity check: nothing else may have registered this op name before we do.
assert "_custom_add" not in IrOp.registry


class CustomError(Exception):
    """Sentinel exception used to verify that context managers restore state."""


@vllm.ir.register_op
def _custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Native (reference) semantics of the op under test.
    return x + y
|
||||
|
||||
|
||||
def test_registration_overloads():
    """Exercise the three ways of constructing an IrOp."""
    for op_name in ("_custom_sub", "_custom_mul", "_custom_div"):
        assert op_name not in IrOp.registry

    # 1) Decorator invoked with parentheses and no arguments.
    @vllm.ir.register_op()
    def _custom_sub(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x - y

    assert _custom_sub.name == "_custom_sub"
    assert IrOp.registry["_custom_sub"] is _custom_sub

    # 2) Decorator with an explicit name overriding the function's name.
    @vllm.ir.register_op(name="_custom_mul")
    def custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y

    assert custom_mul.name == "_custom_mul"
    assert IrOp.registry["_custom_mul"] is custom_mul

    # 3) Direct construction: the op exists but is never registered.
    def _custom_div(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x / y

    custom_div = IrOp("_custom_div", _custom_div)
    assert custom_div.name == "_custom_div"
    assert "_custom_div" not in IrOp.registry

    # Registering the same op name a second time must be rejected.
    with pytest.raises(AssertionError):

        @vllm.ir.register_op
        def _custom_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x * y - 100
|
||||
|
||||
|
||||
def test_no_kw_only_args():
    """Ops with keyword-only parameters must be rejected at registration."""
    with pytest.raises(ValueError, match="keyword-only arguments"):

        @vllm.ir.register_op
        def _custom_kwarg_op(
            x: torch.Tensor, y: torch.Tensor, *, kwarg: int = 0
        ) -> torch.Tensor:
            return x + y + kwarg

    # The failed registration must leave no trace in the registry.
    assert "_custom_kwarg_op" not in IrOp.registry
|
||||
|
||||
|
||||
class TestIrOpCustomAdd:
    """Registration invariants, semantics, and FX visibility of _custom_add."""

    def test_decorated_object(self):
        """The decorated name is the registered IrOp instance itself."""
        assert isinstance(_custom_add, IrOp)
        assert "_custom_add" in IrOp.registry
        assert IrOp.registry["_custom_add"] is _custom_add

    def test_torch_op_is_registered(self):
        # Registration must also create the backing torch custom op.
        assert hasattr(torch.ops.vllm_ir, "_custom_add")
        assert callable(torch.ops.vllm_ir._custom_add.default)

    def test_semantics_match_native(self):
        a = torch.randn(4, 5)
        b = torch.randn(4, 5)

        # With no impl priority configured, the op uses its native function.
        torch.testing.assert_close(_custom_add(a, b), a + b)

    # -------------------------
    # Implementation registration
    # -------------------------

    def test_register_impl_is_non_intrusive(self):
        @_custom_add.register_impl("dummy_provider")
        def dummy_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 123

        assert "dummy_provider" in _custom_add.impls
        assert isinstance(_custom_add.impls["dummy_provider"], IrOpImpl)

        a = torch.ones(2, 2)
        b = torch.ones(2, 2)

        # Merely registering an impl must not change dispatch behavior.
        torch.testing.assert_close(_custom_add(a, b), a + b)

    def test_schema_contains_tensor_signature(self):
        schema = _custom_add._schema_str

        assert "Tensor" in schema
        assert "-> Tensor" in schema

    # -------------------------
    # FX visibility
    # -------------------------

    @pytest.mark.parametrize("enable_torch_wrap", [True, False])
    @pytest.mark.parametrize("symbolic_trace", [True, False])
    def test_trace_sees_single_custom_op(
        self, symbolic_trace: bool, enable_torch_wrap: bool
    ):
        def fn(x, y):
            return _custom_add(x, y)

        def find_fn(target: Any, gm: fx.GraphModule):
            return gm.graph.find_nodes(op="call_function", target=target)

        with pytest.raises(CustomError), vllm.ir.enable_torch_wrap(enable_torch_wrap):
            if symbolic_trace:
                gm = torch.fx.symbolic_trace(fn)
            else:
                gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))

            x1, y1 = torch.rand(5, 4), torch.rand(5, 4)
            out_fx = gm(x1, y1)
            out_eager = fn(x1, y1)

            # Raised deliberately: verifies enable_torch_wrap unwinds its state.
            raise CustomError

        # The traced module must agree with eager execution in all cases.
        torch.testing.assert_close(out_fx, out_eager)

        # IR nodes appear in the graph only when torch wrapping was enabled.
        ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
        expected_count = 1 if enable_torch_wrap else 0
        assert len(ir_nodes) == expected_count, gm.code

        # After the context exits, wrapping reverts to its default (enabled),
        # so a fresh trace sees the IR node again.
        if symbolic_trace:
            gm = torch.fx.symbolic_trace(fn)
        else:
            gm = make_fx(fn)(torch.randn(2, 2), torch.randn(2, 2))

        ir_nodes = find_fn(torch.ops.vllm_ir._custom_add.default, gm)
        assert len(ir_nodes) == 1, gm.code
|
||||
|
||||
|
||||
@_custom_add.register_impl("impl_a")
|
||||
def impl_a(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 10
|
||||
|
||||
|
||||
@_custom_add.register_impl("impl_b")
|
||||
def impl_b(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 20
|
||||
|
||||
|
||||
@_custom_add.register_impl("impl_even", supports_args=lambda x, y: x.size(1) % 2 == 0)
|
||||
def impl_even(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y + 50
|
||||
|
||||
|
||||
class TestIrOpImplDispatch:
    """Impl registration, scoped priorities, and runtime dispatch."""

    def test_register_impl(self):
        assert "impl_a" in _custom_add.impls
        registered = _custom_add.impls["impl_a"]

        assert registered is impl_a
        assert registered.op is _custom_add
        assert registered.provider == "impl_a"
        assert callable(registered.impl_fn)

        # A second registration under the same provider is rejected...
        with pytest.raises(AssertionError):

            @_custom_add.register_impl("impl_a")
            def impl_a_dup(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y + 30

        # ...and the original registration survives untouched.
        assert _custom_add.impls["impl_a"] is impl_a

        # supports_all_args reflects whether a supports_args filter was given.
        assert impl_a.supports_all_args
        assert impl_b.supports_all_args
        assert not impl_even.supports_all_args

    def test_reserved_provider_rejected(self):
        for provider in RESERVED_PROVIDERS:
            with pytest.raises(AssertionError):

                @_custom_add.register_impl(provider)
                def bad_impl(x, y):
                    return x + y

    def test_set_priority_scoped(self):
        assert _custom_add.get_priority() == []

        with _custom_add.set_priority(["impl_even", "impl_b"]):
            assert _custom_add.get_priority() == ["impl_even", "impl_b"]

            # A nested scope shadows the outer priority...
            with _custom_add.set_priority(["impl_b"]):
                assert _custom_add.get_priority() == ["impl_b"]

            # ...and restores it on exit.
            assert _custom_add.get_priority() == ["impl_even", "impl_b"]

            # An exception inside the scope must also restore the priority.
            with pytest.raises(CustomError), _custom_add.set_priority(["impl_a"]):
                assert _custom_add.get_priority() == ["impl_a"]
                raise CustomError

            assert _custom_add.get_priority() == ["impl_even", "impl_b"]

        # Back to the empty default outside all scopes.
        assert _custom_add.get_priority() == []

    def test_dispatch_priority_order(self):
        lhs = torch.tensor(1, dtype=torch.int32)
        rhs = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["impl_b", "impl_a"]):
            assert _custom_add.dispatch(lhs, rhs) is impl_b
            out1 = _custom_add(lhs, rhs)
            out2 = torch.ops.vllm_ir._custom_add(lhs, rhs)

        with _custom_add.set_priority(["impl_a"]):
            assert _custom_add.dispatch(lhs, rhs) is impl_a
            out3 = _custom_add(lhs, rhs)
            out4 = torch.ops.vllm_ir._custom_add(lhs, rhs)

        # First scope dispatched to impl_b (+20 marker).
        assert out1.item() == 1 + 2 + 20
        assert out2.item() == 1 + 2 + 20
        # Second scope dispatched to impl_a (+10 marker).
        assert out3.item() == 1 + 2 + 10
        assert out4.item() == 1 + 2 + 10

    def test_unsupported_impl_filtered(self):
        @_custom_add.register_impl("unsupported", supported=False)
        def impl_bad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x + y + 999

        lhs = torch.tensor(1, dtype=torch.int32)
        rhs = torch.tensor(2, dtype=torch.int32)

        with _custom_add.set_priority(["unsupported", "impl_a"]):
            # The unsupported provider is dropped from the effective priority.
            assert _custom_add.get_priority() == ["impl_a"]
            out = _custom_add(lhs, rhs)

        # impl_bad was skipped, so dispatch landed on impl_a (+10 marker).
        assert out.item() == 1 + 2 + 10

    def test_supports_args_runtime_dispatch_and_warning(
        self, caplog_vllm: pytest.LogCaptureFixture
    ):
        x_even = torch.ones((2, 2), dtype=torch.int32)
        y_even = torch.full((2, 2), 2, dtype=torch.int32)

        x_odd = torch.ones((2, 3), dtype=torch.int32)
        y_odd = torch.full((2, 3), 2, dtype=torch.int32)

        with (
            caplog_vllm.at_level(logging.WARNING),
            _custom_add.set_priority(["impl_even"]),
        ):
            # Setting a partial-support priority warns immediately, before
            # any dispatch happens.
            assert len(caplog_vllm.records) == 1
            message = caplog_vllm.records[0].message
            assert "_custom_add" in message
            assert "fallback to native" in message
            assert "priority" in message

            # "native" is appended as the fallback; dispatch is per-args.
            assert _custom_add.get_priority() == ["impl_even", "native"]
            assert _custom_add.dispatch(x_even, y_even) is impl_even
            assert _custom_add.dispatch(x_odd, y_odd) is _custom_add.impls["native"]

            out_even = _custom_add(x_even, y_even)  # size(1) == 2 → impl_even
            out_odd = _custom_add(x_odd, y_odd)  # size(1) == 3 → native fallback

            # Calling the op must not emit any further warnings.
            assert len(caplog_vllm.records) == 1
            assert torch.all(out_even == 1 + 2 + 50)
            assert torch.all(out_odd == 1 + 2)

    def test_default_priority(
        self, caplog_vllm: pytest.LogCaptureFixture, disable_log_dedup
    ):
        # disable_log_dedup ensures repeated warnings are all captured.
        lhs = torch.tensor([3], dtype=torch.int32)
        rhs = torch.tensor([4], dtype=torch.int32)

        # With no priority set, the op falls back to its native impl.
        assert _custom_add.get_priority() == []
        with caplog_vllm.at_level(logging.WARNING):
            assert _custom_add.dispatch(lhs, rhs) is _custom_add.impls["native"]
            out = _custom_add(lhs, rhs)

        # Native semantics.
        assert out.item() == 3 + 4

        # Both the dispatch() call and the op call warn about the unset priority.
        assert len(caplog_vllm.records) == 2
        message = caplog_vllm.records[0].message.lower()
        assert "_custom_add" in message
        assert "priority not set" in message
|
||||
|
||||
|
||||
# Matmul op with an optional bias; exercises default-argument handling.
@vllm.ir.register_op
def _custom_mm(
    x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
    out = x @ y
    if bias is not None:
        out = out + bias
    return out
|
||||
|
||||
|
||||
def test_default_args():
    """Defaulted args must be applied for both dispatch and supports_args."""

    @_custom_mm.register_impl("impl_mm", supports_args=lambda x, y, bias=None: True)
    def impl_mm(
        x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        tmp = x @ y
        return tmp + 50 if bias is None else tmp + bias + 100

    a = torch.tensor([1, 2], dtype=torch.int32)
    b = torch.tensor([3, 4], dtype=torch.int32)

    # supports_args receives the defaulted bias even when it is omitted.
    assert impl_mm.supports_args(a, b)
    with _custom_mm.set_priority(["impl_mm", "native"]):
        assert _custom_mm.dispatch(a, b) is impl_mm
|
||||
|
||||
|
||||
def test_bad_impl_registrations():
    """Invalid impl registrations must raise and leave no partial state.

    Note: each decorated function below never gets bound (the decorator
    raises), so their bodies are irrelevant; names are numbered to match
    their provider strings. The original had a duplicate definition name
    (``impl_mm_bad_supports_args`` twice) and off-by-one numbering, fixed
    here for consistency.
    """
    # Impl signature missing the optional bias parameter.
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema")
        def impl_mm_bad_schema(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return x @ y - 1

    # Parameter name differs from the native schema ("b" vs "bias").
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_2")
        def impl_mm_bad_schema_2(
            x: torch.Tensor, y: torch.Tensor, b: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + b - 2

    # bias is missing its default / optionality.
    with pytest.raises(ValueError, match="does not match native schema"):

        @_custom_mm.register_impl("impl_mm_bad_schema_3")
        def impl_mm_bad_schema_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            return x @ y + bias - 5

    # supports_args must be a callable, not a plain value.
    with pytest.raises(ValueError, match="supports_args must be a callable"):

        @_custom_mm.register_impl("impl_mm_bad_supports_args", supports_args=True)
        def impl_mm_bad_supports_args(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    # supports_args with too few parameters.
    with pytest.raises(ValueError, match="number of parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_2", supports_args=lambda x, y: True
        )
        def impl_mm_bad_supports_args_2(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 10

    # supports_args must not take keyword-only parameters.
    with pytest.raises(ValueError, match="keyword-only parameters"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_3", supports_args=lambda x, y, *, b: True
        )
        def impl_mm_bad_supports_args_3(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 20

    # supports_args parameter name mismatch ("b" vs "bias").
    with pytest.raises(ValueError, match="does not match native parameter"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_4", supports_args=lambda x, y, b: True
        )
        def impl_mm_bad_supports_args_4(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 30

    # supports_args default differs from the native default.
    with pytest.raises(ValueError, match="does not match native default"):

        @_custom_mm.register_impl(
            "impl_mm_bad_supports_args_5", supports_args=lambda x, y, bias=1: True
        )
        def impl_mm_bad_supports_args_5(
            x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
        ) -> torch.Tensor:
            return x @ y + 40

    # None of the failed registrations leaked into the impl registry.
    assert set(_custom_mm.impls.keys()) == {"impl_mm", "native"}
|
||||
|
||||
|
||||
IMPL_OOT_SRC = """
|
||||
import torch
|
||||
|
||||
@_custom_mm.register_impl("impl_mm_oot")
|
||||
def impl_mm_oot(
|
||||
x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return x @ y - 99
|
||||
"""
|
||||
|
||||
|
||||
def load_custom_mm_module(file_path: Path):
    """Execute the file at ``file_path`` as a module named ``_custom_mm_oot``.

    The ``_custom_mm`` op is injected into the module's namespace first, so
    the ``@_custom_mm.register_impl`` decorator inside the file can run.
    """
    spec = importlib.util.spec_from_file_location("_custom_mm_oot", file_path)
    assert spec is not None and spec.loader is not None

    module = importlib.util.module_from_spec(spec)
    # Inject the op before execution so the decorator resolves.
    module._custom_mm = _custom_mm  # type: ignore[attr-defined]

    # Executing the module body triggers the impl registration.
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def test_uuid_and_oot(tmp_path: Path):
    """Impl uuids must track the impl's source, including out-of-tree impls."""
    file_path = tmp_path / "_custom_mm_oot.py"

    def load_and_get_uuid(source: str) -> str:
        # Write the source, load it (registering the impl), capture the
        # impl's uuid, then unregister so the next load starts clean.
        file_path.write_text(source)
        assert "impl_mm_oot" not in _custom_mm.impls
        _ = load_custom_mm_module(file_path)
        assert "impl_mm_oot" in _custom_mm.impls
        impl_uuid = _custom_mm.impls["impl_mm_oot"].uuid()
        del _custom_mm.impls["impl_mm_oot"]
        return impl_uuid

    uuid_original = load_and_get_uuid(IMPL_OOT_SRC)

    # Changing the file source changes the uuid.
    uuid_modified = load_and_get_uuid(IMPL_OOT_SRC + " # added file source")
    assert uuid_modified != uuid_original

    # Restoring the original source restores the original uuid.
    uuid_restored = load_and_get_uuid(IMPL_OOT_SRC)
    assert uuid_restored == uuid_original
    assert uuid_restored != uuid_modified
|
||||
Reference in New Issue
Block a user