[Kernel][Misc] register ops to prevent graph breaks (#6917)
Co-authored-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import Optional, Tuple
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows, quantize_weights)
|
||||
@@ -76,6 +77,8 @@ def machete_quantize_and_pack(w: torch.Tensor,
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
w_q_machete = ops.machete_prepack_B(w_q, wtype)
|
||||
|
||||
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
|
||||
|
||||
return w_ref, w_q_machete, w_s, w_zp
|
||||
|
||||
|
||||
@@ -146,6 +149,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
|
||||
schedule=schedule,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.machete_gemm,
|
||||
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
|
||||
w_zp, w_s), group_size, None, None, None, schedule))
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
|
||||
Reference in New Issue
Block a user