[Kernel] optimize performance of gptq marlin kernel when n is small (#14138)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.platforms import current_platform
|
||||
@@ -290,6 +291,23 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
|
||||
return output
|
||||
|
||||
|
||||
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
                                 dtype: torch.dtype) -> bool:
    """Decide whether the Marlin GEMM should use the atomicAdd reduce path.

    Args:
        m: Row count of the flattened input (callers in this file pass
            ``reshaped_x.size(0)``). Values below 64 are treated as 64.
        n: Output width (callers pass ``output_size_per_partition``).
        k: Reduction width (callers pass ``reshaped_x.size(1)``).
        device: Device the input tensor lives on.
        dtype: Dtype of the input tensor.

    Returns:
        ``True`` only if the feature is enabled through
        ``VLLM_MARLIN_USE_ATOMIC_ADD``, the device is a CUDA device, the
        dtype/compute-capability pair is allowed, and the problem shape
        satisfies ``max(m, 64) * n < 64 * 2048`` with ``k >= 2048``.
    """
    # atomicAdd reduce is opt-in: it stays off unless the user sets
    # VLLM_MARLIN_USE_ATOMIC_ADD=1.
    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
        return False
    if device.type != "cuda":
        return False

    # sm8x has no native atomicAdd for bfloat16, so skip that combination.
    major = torch.cuda.get_device_capability(device)[0]
    if dtype == torch.bfloat16 and major < 9:
        return False

    # atomicAdd only outperforms the global reduce when m*n is small and
    # k is large.
    effective_m = m if m > 64 else 64
    return k >= 2048 and effective_m * n < 64 * 2048
|
||||
|
||||
|
||||
def apply_gptq_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
@@ -307,6 +325,12 @@ def apply_gptq_marlin_linear(
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition, )
|
||||
|
||||
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
|
||||
n=output_size_per_partition,
|
||||
k=reshaped_x.size(1),
|
||||
device=input.device,
|
||||
dtype=input.dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
weight,
|
||||
weight_scale,
|
||||
@@ -320,6 +344,7 @@ def apply_gptq_marlin_linear(
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=is_k_full,
|
||||
has_zp=False,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
@@ -345,6 +370,12 @@ def apply_awq_marlin_linear(
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition, )
|
||||
|
||||
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
|
||||
n=output_size_per_partition,
|
||||
k=reshaped_x.size(1),
|
||||
device=input.device,
|
||||
dtype=input.dtype)
|
||||
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
weight,
|
||||
weight_scale,
|
||||
@@ -358,6 +389,7 @@ def apply_awq_marlin_linear(
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=True,
|
||||
has_zp=True,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user