[Bugfix] Fix triton import with local TritonPlaceholder (#17446)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
92
tests/test_triton_utils.py
Normal file
92
tests/test_triton_utils.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import sys
|
||||
import types
|
||||
from unittest import mock
|
||||
|
||||
from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
|
||||
TritonPlaceholder)
|
||||
|
||||
|
||||
def test_triton_placeholder_is_module():
|
||||
triton = TritonPlaceholder()
|
||||
assert isinstance(triton, types.ModuleType)
|
||||
assert triton.__name__ == "triton"
|
||||
|
||||
|
||||
def test_triton_language_placeholder_is_module():
|
||||
triton_language = TritonLanguagePlaceholder()
|
||||
assert isinstance(triton_language, types.ModuleType)
|
||||
assert triton_language.__name__ == "triton.language"
|
||||
|
||||
|
||||
def test_triton_placeholder_decorators():
|
||||
triton = TritonPlaceholder()
|
||||
|
||||
@triton.jit
|
||||
def foo(x):
|
||||
return x
|
||||
|
||||
@triton.autotune
|
||||
def bar(x):
|
||||
return x
|
||||
|
||||
@triton.heuristics
|
||||
def baz(x):
|
||||
return x
|
||||
|
||||
assert foo(1) == 1
|
||||
assert bar(2) == 2
|
||||
assert baz(3) == 3
|
||||
|
||||
|
||||
def test_triton_placeholder_decorators_with_args():
|
||||
triton = TritonPlaceholder()
|
||||
|
||||
@triton.jit(debug=True)
|
||||
def foo(x):
|
||||
return x
|
||||
|
||||
@triton.autotune(configs=[], key="x")
|
||||
def bar(x):
|
||||
return x
|
||||
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
|
||||
def baz(x):
|
||||
return x
|
||||
|
||||
assert foo(1) == 1
|
||||
assert bar(2) == 2
|
||||
assert baz(3) == 3
|
||||
|
||||
|
||||
def test_triton_placeholder_language():
|
||||
lang = TritonLanguagePlaceholder()
|
||||
assert isinstance(lang, types.ModuleType)
|
||||
assert lang.__name__ == "triton.language"
|
||||
assert lang.constexpr is None
|
||||
assert lang.dtype is None
|
||||
assert lang.int64 is None
|
||||
|
||||
|
||||
def test_triton_placeholder_language_from_parent():
|
||||
triton = TritonPlaceholder()
|
||||
lang = triton.language
|
||||
assert isinstance(lang, TritonLanguagePlaceholder)
|
||||
|
||||
|
||||
def test_no_triton_fallback():
|
||||
# clear existing triton modules
|
||||
sys.modules.pop("triton", None)
|
||||
sys.modules.pop("triton.language", None)
|
||||
sys.modules.pop("vllm.triton_utils", None)
|
||||
sys.modules.pop("vllm.triton_utils.importing", None)
|
||||
|
||||
# mock triton not being installed
|
||||
with mock.patch.dict(sys.modules, {"triton": None}):
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
assert HAS_TRITON is False
|
||||
assert triton.__class__.__name__ == "TritonPlaceholder"
|
||||
assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
|
||||
assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
|
||||
Reference in New Issue
Block a user