[Kernel] some optimizations for dense marlin and moe marlin (#16850)
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@@ -7,12 +7,15 @@ import torch
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
 from .quant_utils import pack_cols, unpack_cols
 
+logger = init_logger(__name__)
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
@@ -29,9 +32,11 @@ USE_FP32_REDUCE_DEFAULT = True
 # For binary size and compile time, we don't support the same types for with and
 # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
 # TODO: we may want to move this into the C++ so its closer to the actual impl
-def query_marlin_supported_quant_types(has_zp: bool,
-                                       device_capability: Optional[int] = None
-                                       ):
+def query_marlin_supported_quant_types(
+    has_zp: bool,
+    include_fp_type: bool = True,
+    device_capability: Optional[int] = None,
+):
     if device_capability is None:
         capability_tuple = current_platform.get_device_capability()
         device_capability = (-1 if capability_tuple is None else
@@ -42,12 +47,13 @@ def query_marlin_supported_quant_types(has_zp: bool,
 
     if has_zp:
         # AWQ style, unsigned + runtime zero-point
-        return [scalar_types.uint4, scalar_types.uint8]
+        return [scalar_types.uint4]
     else:
         # GPTQ style, unsigned + symmetric bias
-        # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able
-        # to add `scalar_types.float8_e4m3fn` here
-        return [scalar_types.uint4b8, scalar_types.uint8b128]
+        res = [scalar_types.uint4b8, scalar_types.uint8b128]
+        if include_fp_type:
+            res += [scalar_types.float8_e4m3fn]
+        return res
 
 
 def _check_marlin_supported(
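A hedged usage sketch of the reworked helper follows (illustrative only, not part of this commit; the import path assumes the file is `vllm/model_executor/layers/quantization/utils/marlin_utils.py`, and the exact results also depend on the device-capability check above):

```python
# Illustrative usage of query_marlin_supported_quant_types as changed above.
# Assumes a GPU with compute capability high enough to pass the checks.
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    query_marlin_supported_quant_types)

# GPTQ style (no runtime zero-point): int types, plus fp8 when requested.
print(query_marlin_supported_quant_types(has_zp=False))
# -> [uint4b8, uint8b128, float8_e4m3fn]
print(query_marlin_supported_quant_types(has_zp=False, include_fp_type=False))
# -> [uint4b8, uint8b128]

# AWQ style (runtime zero-point) now reports uint4 only.
print(query_marlin_supported_quant_types(has_zp=True))
# -> [uint4]
```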
@@ -62,7 +68,7 @@ def _check_marlin_supported(
                              capability_tuple.to_int())
 
     supported_types = query_marlin_supported_quant_types(
-        has_zp, device_capability)
+        has_zp, True, device_capability)
 
     if quant_type not in supported_types:
         return (False, f"Marlin does not support weight_bits = {quant_type}. "
@@ -175,6 +181,17 @@ def marlin_make_workspace(output_size_per_partition: int,
                        requires_grad=False)
 
 
+def marlin_make_workspace_new(device: torch.device,
+                              max_blocks_per_sm: int = 1) -> torch.Tensor:
+    # In the new marlin kernel, we use the num of threadblocks as workspace
+    # size. The num of threadblocks is sms_count * max_blocks_per_sm.
+    sms = torch.cuda.get_device_properties(device).multi_processor_count
+    return torch.zeros(sms * max_blocks_per_sm,
+                       dtype=torch.int,
+                       device=device,
+                       requires_grad=False)
+
+
 def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
     return (not act_order) or (act_order and not is_row_parallel)
 
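A brief usage sketch of the new workspace helper (illustrative only; assumes a visible CUDA device). The point of the change is that the buffer is sized by the number of possible threadblocks on the device rather than by output_size_per_partition:

```python
# Illustrative only; the import path below is assumed (this diff does not
# name the file), pointing at the helper added above.
import torch
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_make_workspace_new)

device = torch.device("cuda:0")
sms = torch.cuda.get_device_properties(device).multi_processor_count

ws = marlin_make_workspace_new(device)      # int32 zeros, one slot per SM
ws4 = marlin_make_workspace_new(device, 4)  # up to four blocks per SM
assert ws.numel() == sms and ws4.numel() == 4 * sms
assert ws.dtype == torch.int32
```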
@@ -304,21 +321,50 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
     return output
 
 
+def maybe_warn_marlin_atomic_add(device, dtype):
+    if torch.compiler.is_dynamo_compiling():
+        return
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        logger.info_once(
+            "You are running Marlin kernel with bf16 on GPUs before SM90. "
+            "You can consider changing to fp16 to achieve better performance "
+            "if possible.")
+
+
+def maybe_warn_marlin_atomic_add_env():
+    if torch.compiler.is_dynamo_compiling():
+        return
+    if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+        return
+    logger.info_once(
+        "Marlin kernel can achieve better performance for small size_n "
+        "with the experimental use_atomic_add feature. "
+        "You can consider setting the environment variable "
+        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.")
+
+
 def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device,
                                  dtype: torch.dtype) -> bool:
 
+    # the performance of atomicAdd is better than global reduce
+    # only when m*n is small and k is large
+    if n >= 2048 or k < 2048 or device.type != "cuda":
+        return False
+
     # disable atomicAdd reduce by default,
     # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
-    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda":
+    if not envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+        maybe_warn_marlin_atomic_add_env()
         return False
 
     # sm8x doesn't support atomicAdd + bfloat16 natively
     device_capability = torch.cuda.get_device_capability(device)
     if device_capability[0] < 9 and dtype == torch.bfloat16:
+        maybe_warn_marlin_atomic_add(device, dtype)
         return False
 
-    # the performance of atomicAdd is better than global reduce
-    # only when m*n is small and k is large
-    return n < 2048 and k >= 2048
+    return True
 
 
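The heuristic above boils down to three gates: a shape gate (atomicAdd only beats the global-reduce path for small n and large k), an opt-in gate (the VLLM_MARLIN_USE_ATOMIC_ADD environment variable), and an architecture gate (no native bf16 atomicAdd before SM90). Below is a minimal standalone restatement of that logic, a hedged sketch with the environment gate passed in as a plain bool so it runs without vLLM installed:

```python
# Standalone restatement of the decision above (sketch only; the real code
# also requires device.type == "cuda" and reads envs.VLLM_MARLIN_USE_ATOMIC_ADD).
import torch

def atomic_add_reduce_ok(n: int, k: int, dtype: torch.dtype,
                         sm_major: int, env_enabled: bool) -> bool:
    # shape gate: atomicAdd only pays off for narrow-N, deep-K GEMMs
    if n >= 2048 or k < 2048:
        return False
    # opt-in gate: VLLM_MARLIN_USE_ATOMIC_ADD=1
    if not env_enabled:
        return False
    # arch gate: pre-SM90 GPUs lack native bf16 atomicAdd
    if sm_major < 9 and dtype == torch.bfloat16:
        return False
    return True

assert atomic_add_reduce_ok(1024, 4096, torch.float16, sm_major=8, env_enabled=True)
assert not atomic_add_reduce_ok(1024, 4096, torch.bfloat16, sm_major=8, env_enabled=True)
assert not atomic_add_reduce_ok(4096, 4096, torch.float16, sm_major=9, env_enabled=True)
```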
@@ -332,7 +378,6 @@ def apply_gptq_marlin_linear(
     wtype: ScalarType,
     output_size_per_partition: int,
     input_size_per_partition: int,
-    has_zp: bool,
     is_k_full: bool,
     bias: Optional[torch.Tensor] = None,
     use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
@@ -346,6 +391,7 @@ def apply_gptq_marlin_linear(
                          dtype=input.dtype)
 
     output = ops.gptq_marlin_gemm(reshaped_x,
+                                  None,
                                   weight,
                                   weight_scale,
                                   weight_zp,
@@ -358,7 +404,6 @@ def apply_gptq_marlin_linear(
                                   size_k=input_size_per_partition,
                                   is_k_full=is_k_full,
                                   use_atomic_add=use_atomic_add,
-                                  has_zp=has_zp,
                                   use_fp32_reduce=use_fp32_reduce,
                                   is_zp_float=False)
 
@@ -391,6 +436,7 @@ def apply_awq_marlin_linear(
                          dtype=input.dtype)
 
     output = ops.gptq_marlin_gemm(reshaped_x,
+                                  None,
                                   weight,
                                   weight_scale,
                                   weight_zp,
@@ -401,8 +447,6 @@ def apply_awq_marlin_linear(
                                   size_m=reshaped_x.shape[0],
                                   size_n=output_size_per_partition,
                                   size_k=input_size_per_partition,
-                                  is_k_full=True,
-                                  has_zp=True,
                                   use_atomic_add=use_atomic_add,
                                   use_fp32_reduce=use_fp32_reduce,
                                   is_zp_float=False)
 
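End to end, the only user-facing switch touched by this change is the VLLM_MARLIN_USE_ATOMIC_ADD environment variable; everything else (workspace sizing, the atomicAdd heuristic, the reworked gemm call) is picked up automatically. A hedged usage example follows; the model name is only a placeholder for any GPTQ/AWQ checkpoint that vLLM routes through the Marlin path:

```python
# Hypothetical opt-in example, not part of this commit. vllm.envs reads
# environment variables lazily, so setting the flag before constructing the
# engine should be sufficient.
import os
os.environ["VLLM_MARLIN_USE_ATOMIC_ADD"] = "1"  # opt-in; default is off

from vllm import LLM

# Placeholder checkpoint; any model served via the gptq_marlin kernels works.
llm = LLM(model="TheBloke/Llama-2-7B-GPTQ", quantization="gptq_marlin")
print(llm.generate("Hello")[0].outputs[0].text)
```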