[Bugfix] Try to handle older versions of pytorch (#9086)
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
|
||||
reason="AWQ is not supported on this GPU type.")
|
||||
def test_awq_dequantize_opcheck():
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||
qweight = torch.randint(-2000000000,
|
||||
@@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
|
||||
(qweight, scales, zeros, split_k_iters, thx, thy))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
|
||||
reason="AWQ is not supported on this GPU type.")
|
||||
def test_awq_gemm_opcheck():
|
||||
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
|
||||
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
|
||||
|
||||
Reference in New Issue
Block a user