[Kernel] Support deep_gemm for linear methods (#19085)
Signed-off-by: artetaout <lulala341@gmail.com>
@@ -3,12 +3,14 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -20,6 +22,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
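An aside on the has_deep_gemm probe above: importlib.util.find_spec reports whether a package is importable without actually importing it, so availability can be checked at module load with no import-time side effects. A minimal standalone sketch (illustrative, not part of the commit):

import importlib.util

# find_spec returns a ModuleSpec when "deep_gemm" is resolvable on sys.path
# and None otherwise; the module itself is never imported here, so any
# import-time work (e.g. loading CUDA extensions) is deferred to first use.
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
print("deep_gemm available:", has_deep_gemm)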
@@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func(
    return w8a8_block_fp8_matmul


def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
    """
    Check whether DeepGEMM should be used, based on the output dtype and the
    weight shape. DeepGEMM is only supported for bfloat16 output with both
    weight dimensions divisible by 128.
    """
    return (current_platform.is_cuda()
            and current_platform.is_device_capability(90) and has_deep_gemm
            and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16
            and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)


# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
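To make the new gate concrete: the last two conjuncts require both weight dimensions to tile evenly into 128-wide blocks, matching the 128x128 block granularity of the block-scaled kernel, while the remaining conditions restrict the path to CUDA devices of compute capability 9.0 with deep_gemm installed and VLLM_USE_DEEP_GEMM enabled. A standalone rendering of the shape condition, using hypothetical weight tensors (illustrative only, not committed code):

import torch

def _deepgemm_shapes_ok(weight: torch.Tensor, block: int = 128) -> bool:
    # dim 0 (output features) and dim 1 (input features) must both be
    # divisible by the 128-element block size.
    return weight.shape[0] % block == 0 and weight.shape[1] % block == 0

assert _deepgemm_shapes_ok(torch.empty(1024, 4096))      # both divisible by 128
assert not _deepgemm_shapes_ok(torch.empty(1000, 4096))  # 1000 % 128 != 0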
@@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear(
    # View input as 2D matrix for fp8 methods
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]
    output_dtype = input.dtype

    if should_use_deepgemm(output_dtype, weight):
        # Quantize activations per token in groups of block_size[1]
        # channels, with scales in column-major layout for DeepGEMM.
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d,
            block_size[1],
            column_major_scales=True,
        )
        output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
            q_input,
            weight,
            x_scale,
            weight_scale,
            block_size,
            output_dtype=output_dtype)
        if bias is not None:
            output += bias
        return output.to(dtype=output_dtype).view(*output_shape)

    if current_platform.is_cuda():
        if current_platform.has_device_capability(100):
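For intuition, a shape-level sketch of the branch above, assuming block_size = [128, 128] and dummy dimensions; the shapes follow the block-quantization convention used in this file, but the snippet itself is illustrative, not committed code:

import torch

M, K, N = 16, 4096, 1024       # tokens, input features, output features
block_n, block_k = 128, 128    # block_size[0], block_size[1]

x = torch.randn(M, K, dtype=torch.bfloat16)

# per_token_group_quant_fp8 emits one fp8 value per element and one scale
# per (token, group of block_k channels):
#   q_input: (M, K) fp8, x_scale: (M, K // block_k)
# The fp8 weight carries one scale per block_n x block_k tile:
#   weight: (N, K) fp8, weight_scale: (N // block_n, K // block_k)
# w8a8_block_fp8_matmul_deepgemm then returns (M, N) in output_dtype.
print(tuple(x.shape))                # q_input shape: (16, 4096)
print((M, K // block_k))             # x_scale shape: (16, 32)
print((N // block_n, K // block_k))  # weight_scale shape: (8, 32)

Note that the DeepGEMM branch returns early, so the existing CUTLASS/Triton dispatch below is only reached when the gate fails.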
@@ -134,7 +173,6 @@ def apply_w8a8_block_fp8_linear(
    w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
        use_cutlass, use_aiter_and_is_supported)

    if use_cutlass:
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=use_cutlass)