[Bugfix] Try to handle older versions of pytorch (#9086)

This commit is contained in:
bnellnm
2024-10-08 17:28:12 -04:00
committed by GitHub
parent de24046fcd
commit bd37b9fbe2
3 changed files with 41 additions and 21 deletions

View File

@@ -1,11 +1,14 @@
import os
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.")
def test_awq_dequantize_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
qweight = torch.randint(-2000000000,
@@ -21,6 +24,8 @@ def test_awq_dequantize_opcheck():
(qweight, scales, zeros, split_k_iters, thx, thy))
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.")
def test_awq_gemm_opcheck():
os.environ["VLLM_USE_TRITON_AWQ"] = "0"
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)