[Kernel] some optimizations for dense marlin and moe marlin (#16850)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Jinzhen Lin
2025-05-06 00:39:30 +08:00
committed by GitHub
parent f62cad6431
commit 1d0c9d6b2d
26 changed files with 3512 additions and 3268 deletions


@@ -7,12 +7,15 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
# To limit binary size and compile time, we don't support the same set of types
# with and without runtime zero-point. We support the common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ code so it's closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
def query_marlin_supported_quant_types(
has_zp: bool,
include_fp_type: bool = True,
device_capability: Optional[int] = None,
):
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
return [scalar_types.uint4]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
res = [scalar_types.uint4b8, scalar_types.uint8b128]
if include_fp_type:
res += [scalar_types.float8_e4m3fn]
return res
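
A minimal usage sketch of the new signature (not part of the diff); the import path and the explicit device_capability value are assumptions for illustration:

from vllm.scalar_type import scalar_types
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

# GPTQ-style (no runtime zero-point): uint4b8/uint8b128, plus fp8 when requested.
gptq_types = query_marlin_supported_quant_types(
    has_zp=False, include_fp_type=True, device_capability=90)
assert scalar_types.float8_e4m3fn in gptq_types

# AWQ-style (runtime zero-point): only uint4 remains supported after this change.
awq_types = query_marlin_supported_quant_types(
    has_zp=True, device_capability=90)
assert awq_types == [scalar_types.uint4]
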
def _check_marlin_supported(
@@ -62,7 +68,7 @@ def _check_marlin_supported(
capability_tuple.to_int())
supported_types = query_marlin_supported_quant_types(
has_zp, device_capability)
has_zp, True, device_capability)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad=False)
def marlin_make_workspace_new(device: torch.device,
max_blocks_per_sm: int = 1) -> torch.Tensor:
# In the new marlin kernel, we use the number of threadblocks as the workspace
# size. The number of threadblocks is sms_count * max_blocks_per_sm.
sms = torch.cuda.get_device_properties(device).multi_processor_count
return torch.zeros(sms * max_blocks_per_sm,
dtype=torch.int,
device=device,
requires_grad=False)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
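
A minimal sketch of the new SM-count-based workspace sizing (not part of the diff), assuming a CUDA device is available and marlin_make_workspace_new is imported from the same module:

import torch

device = torch.device("cuda:0")
sms = torch.cuda.get_device_properties(device).multi_processor_count

# One int32 slot per threadblock that can be resident at once
# (sms_count * max_blocks_per_sm), instead of a size derived from the
# output partition as in the older marlin_make_workspace.
workspace = marlin_make_workspace_new(device, max_blocks_per_sm=1)
assert workspace.numel() == sms
assert workspace.dtype == torch.int32
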
@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return output
def maybe_warn_marlin_atomic_add(device, dtype):
if torch.compiler.is_dynamo_compiling():
return
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
logger.info_once(
"You are running Marlin kernel with bf16 on GPUs before SM90. "
"You can consider change to fp16 to achieve better performance "
"if possible.")
def maybe_warn_marlin_atomic_add_env():
if torch.compiler.is_dynamo_compiling():
return
if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
return
logger.info_once(
"Marlin kernel can achieve better performance for small size_n "
"with experimental use_atomic_add feature. "
"You can consider set environment variable "
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
dtype: torch.dtype) -> bool:
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
if n >= 2048 or k < 2048 or device.type != "cuda":
return False
# disable atomicAdd reduce by default,
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
maybe_warn_marlin_atomic_add_env()
return False
# sm8x doesn't support atomicAdd + bfloat16 natively
device_capability = torch.cuda.get_device_capability(device)
if device_capability[0] < 9 and dtype == torch.bfloat16:
maybe_warn_marlin_atomic_add(device, dtype)
return False
# the performance of atomicAdd is better than global reduce
# only when m*n is small and k is large
return n < 2048 and k >= 2048
return True
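
A small sketch (not part of the diff) of how the heuristic above resolves for a few GEMM shapes, assuming a CUDA device, fp16 activations, and VLLM_MARLIN_USE_ATOMIC_ADD=1 in the environment:

import torch

device = torch.device("cuda:0")

# Deep reduction with a narrow output tile: atomic-add reduce is used.
should_use_atomic_add_reduce(m=16, n=1024, k=8192, device=device,
                             dtype=torch.float16)  # -> True

# Wide output (n >= 2048) or shallow reduction (k < 2048): global reduce instead.
should_use_atomic_add_reduce(m=16, n=4096, k=8192, device=device,
                             dtype=torch.float16)  # -> False
should_use_atomic_add_reduce(m=16, n=1024, k=1024, device=device,
                             dtype=torch.float16)  # -> False
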
def apply_gptq_marlin_linear(
@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
wtype: ScalarType,
output_size_per_partition: int,
input_size_per_partition: int,
has_zp: bool,
is_k_full: bool,
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
weight_scale,
weight_zp,
@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
size_k=input_size_per_partition,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)
@@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
dtype=input.dtype)
output = ops.gptq_marlin_gemm(reshaped_x,
None,
weight,
weight_scale,
weight_zp,
@@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True,
use_atomic_add=use_atomic_add,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False)