Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,11 +4,15 @@
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_awq_triton.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.awq_triton import (
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
|
||||
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
|
||||
awq_dequantize_triton,
|
||||
awq_gemm_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device = "cuda"
|
||||
@@ -33,23 +37,24 @@ def reverse_awq_order(t: torch.Tensor):
|
||||
# qweights - [R , C // 8], int32
|
||||
# scales - [R // G, C ], float16
|
||||
# zeros - [R // G, C // 8], int32
|
||||
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
|
||||
qzeros: torch.Tensor,
|
||||
group_size: int) -> torch.Tensor:
|
||||
|
||||
def awq_dequantize_torch(
|
||||
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
|
||||
) -> torch.Tensor:
|
||||
if group_size == -1:
|
||||
group_size = qweight.shape[0]
|
||||
|
||||
bits = 4
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device)
|
||||
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None],
|
||||
shifts[None, None, :]).to(torch.int8)
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
|
||||
torch.int8
|
||||
)
|
||||
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
|
||||
zeros = torch.bitwise_right_shift(qzeros[:, :, None],
|
||||
shifts[None, None, :]).to(torch.int8)
|
||||
zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
|
||||
torch.int8
|
||||
)
|
||||
zeros = zeros.view(qzeros.shape[0], -1)
|
||||
zeros = reverse_awq_order(zeros)
|
||||
|
||||
@@ -70,7 +75,6 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
|
||||
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
def test_dequantize(qweight_rows, qweight_cols, group_size):
|
||||
|
||||
if group_size == -1:
|
||||
group_size = qweight_rows
|
||||
|
||||
@@ -84,25 +88,27 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
qweight = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_rows, qweight_cols),
|
||||
dtype=qweight_dtype,
|
||||
device=device)
|
||||
scales = torch.rand(scales_rows,
|
||||
scales_cols,
|
||||
dtype=scales_dtype,
|
||||
device=device)
|
||||
zeros = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(zeros_rows, zeros_cols),
|
||||
dtype=zeros_dtype,
|
||||
device=device)
|
||||
qweight = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_rows, qweight_cols),
|
||||
dtype=qweight_dtype,
|
||||
device=device,
|
||||
)
|
||||
scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device)
|
||||
zeros = torch.randint(
|
||||
0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(zeros_rows, zeros_cols),
|
||||
dtype=zeros_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
|
||||
|
||||
assert (not torch.any(torch.isinf(iweights_triton))
|
||||
and not torch.any(torch.isnan(iweights_triton)))
|
||||
assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
|
||||
torch.isnan(iweights_triton)
|
||||
)
|
||||
|
||||
iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
|
||||
|
||||
@@ -119,7 +125,6 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
|
||||
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("splitK", [1, 8])
|
||||
def test_gemm(N, K, M, splitK, group_size):
|
||||
|
||||
if group_size == -1:
|
||||
group_size = K
|
||||
|
||||
@@ -138,35 +143,29 @@ def test_gemm(N, K, M, splitK, group_size):
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
input = torch.rand((input_rows, input_cols),
|
||||
dtype=input_dtype,
|
||||
device=device)
|
||||
qweight = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qweight_rows, qweight_cols),
|
||||
device=device)
|
||||
qzeros = torch.randint(0,
|
||||
torch.iinfo(torch.int32).max,
|
||||
(qzeros_rows, qzeros_cols),
|
||||
device=device)
|
||||
scales = torch.rand((scales_rows, scales_cols),
|
||||
dtype=scales_dtype,
|
||||
device=device)
|
||||
input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device)
|
||||
qweight = torch.randint(
|
||||
0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device
|
||||
)
|
||||
qzeros = torch.randint(
|
||||
0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device
|
||||
)
|
||||
scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device)
|
||||
|
||||
output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
|
||||
split_k_iters)
|
||||
output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
|
||||
|
||||
assert (not torch.any(torch.isinf(output_triton))
|
||||
and not torch.any(torch.isnan(output_triton)))
|
||||
assert not torch.any(torch.isinf(output_triton)) and not torch.any(
|
||||
torch.isnan(output_triton)
|
||||
)
|
||||
|
||||
dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
|
||||
|
||||
output_torch = torch.matmul(input, dequantized_weights)
|
||||
|
||||
assert (not torch.any(torch.isinf(output_torch))
|
||||
and not torch.any(torch.isnan(output_torch)))
|
||||
assert not torch.any(torch.isinf(output_torch)) and not torch.any(
|
||||
torch.isnan(output_torch)
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output_triton.cpu(),
|
||||
output_torch.cpu(),
|
||||
atol=1e-1,
|
||||
rtol=1e-1)
|
||||
torch.testing.assert_close(
|
||||
output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user