[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.plugins import load_general_plugins
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR, STR_INVALID_VAL
|
||||
|
||||
|
||||
@@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
||||
backend = get_attn_backend(16, torch.float16, "auto", 16, False)
|
||||
assert backend.get_name() == "Dummy_Backend"
|
||||
|
||||
|
||||
def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
|
||||
# simulate workload by running an example
|
||||
load_general_plugins()
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
|
||||
assert layer.__class__.__name__ == "DummyRotaryEmbedding", (
|
||||
f"Expected DummyRotaryEmbedding, got {layer.__class__.__name__}, "
|
||||
"possibly because the custom op is not registered correctly.")
|
||||
assert hasattr(layer, "addition_config"), (
|
||||
"Expected DummyRotaryEmbedding to have an 'addition_config' attribute, "
|
||||
"which is set by the custom op.")
|
||||
|
||||
Reference in New Issue
Block a user