[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
import itertools
|
||||
import random
|
||||
from numbers import Number
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union
|
||||
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
|
||||
Union)
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -13,6 +14,21 @@ from vllm.attention.backends.xformers import XFormersBackend
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
|
||||
make_tensor_with_pad)
|
||||
|
||||
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# The full opcheck suite: everything in the default set plus the
# dynamic AOT-dispatch test that is currently problematic on 2.4.
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (*DEFAULT_OPCHECK_TEST_UTILS,
                                           "test_aot_dispatch_dynamic")
|
||||
|
||||
|
||||
class QKVInputs(NamedTuple):
|
||||
'''
|
||||
@@ -926,3 +942,19 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
|
||||
ideal_output = test_params.packed_qkvo.ideal_output
|
||||
torch.testing.assert_close(ideal_output,
|
||||
output_under_test.view_as(ideal_output))
|
||||
|
||||
|
||||
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    """Thin wrapper around ``torch.library.opcheck``.

    Runs the selected opcheck test suites against ``op`` with the given
    positional/keyword arguments and returns the per-suite result mapping.
    When ``cond`` is false the check is skipped entirely and an empty dict
    is returned, so call sites can gate the (potentially slow) check on a
    platform/feature flag without extra branching.
    """
    if not cond:
        return {}
    return torch.library.opcheck(op,
                                 args,
                                 kwargs,
                                 test_utils=test_utils,
                                 raise_exception=raise_exception)
|
||||
|
||||
Reference in New Issue
Block a user