[Kernel] optimize performance of gptq marlin kernel when n is small (#14138)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
This commit is contained in:
Jinzhen Lin
2025-03-08 00:53:38 +08:00
committed by GitHub
parent 58abe35455
commit d0feea31c7
6 changed files with 99 additions and 24 deletions

View File

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple
import numpy
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
@@ -290,6 +291,23 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return output
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
                                 dtype: torch.dtype) -> bool:
    """Decide whether the marlin GEMM should reduce with atomicAdd.

    Args:
        m: Number of rows of the (reshaped) input.
        n: Output size (per partition) of the GEMM.
        k: Reduction dimension of the GEMM.
        device: Device the input lives on.
        dtype: Data type of the input.

    Returns:
        True only when the feature is enabled through
        VLLM_MARLIN_USE_ATOMIC_ADD, the device is a CUDA device that can
        handle the dtype, and the problem shape passes the size heuristic.
    """
    # atomicAdd reduce is opt-in (VLLM_MARLIN_USE_ATOMIC_ADD=1) and is only
    # considered for CUDA devices.
    if not (envs.VLLM_MARLIN_USE_ATOMIC_ADD and device.type == "cuda"):
        return False

    # sm8x doesn't support atomicAdd + bfloat16 natively, so reject
    # bfloat16 on anything with a major compute capability below 9.
    major_capability = torch.cuda.get_device_capability(device)[0]
    if dtype == torch.bfloat16 and major_capability < 9:
        return False

    # atomicAdd beats the global reduce only when m*n is small and k is
    # large; m is clamped to at least 64 before the comparison.
    clamped_m = m if m > 64 else 64
    output_is_small = clamped_m * n < 64 * 2048
    return output_is_small and k >= 2048
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -307,6 +325,12 @@ def apply_gptq_marlin_linear(
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
@@ -320,6 +344,7 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
@@ -345,6 +370,12 @@ def apply_awq_marlin_linear(
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0),
n=output_size_per_partition,
k=reshaped_x.size(1),
device=input.device,
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
@@ -358,6 +389,7 @@ def apply_awq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)