[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import Optional, Type
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -108,6 +109,9 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
|
||||
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
|
||||
@@ -341,6 +345,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
|
||||
|
||||
if azp_per_token:
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
|
||||
func_bias))
|
||||
else:
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
|
||||
func_bias))
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
def test_cutlass_subset():
|
||||
|
||||
Reference in New Issue
Block a user