[Kernel] Support deep_gemm for linear methods (#19085)

Signed-off-by: artetaout <lulala341@gmail.com>
This commit is contained in:
artetaout
2025-06-11 15:14:45 +08:00
committed by GitHub
parent 5039ec2336
commit b8e809a057
3 changed files with 124 additions and 1 deletions

View File

@@ -3,12 +3,14 @@
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import importlib.util
import json
import os
from typing import Any, Callable, Optional, Union
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -20,6 +22,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
@@ -98,6 +101,19 @@ def dispatch_w8a8_blockscale_func(
return w8a8_block_fp8_matmul
def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
    """
    Decide whether the DeepGEMM kernel path applies to this linear layer.

    The checks run in a fixed order and stop at the first one that fails:

    1. the current platform is CUDA with device capability 90,
    2. the ``deep_gemm`` package is importable and ``VLLM_USE_DEEP_GEMM``
       is enabled,
    3. the requested output dtype is ``torch.bfloat16``,
    4. the first two dimensions of ``weight`` are both multiples of 128.

    Args:
        output_dtype: dtype the matmul result is expected to have.
        weight: weight tensor; only ``shape[0]`` and ``shape[1]`` are read.

    Returns:
        A truthy value when every check passes; otherwise the falsy result
        of the first check that failed.
    """
    # Every step is an ``and`` on the running result, so a failed earlier
    # check short-circuits all later ones (the weight shape is never touched
    # unless the platform, package, env flag and dtype checks have passed).
    eligible = (current_platform.is_cuda()
                and current_platform.is_device_capability(90))
    eligible = eligible and has_deep_gemm and envs.VLLM_USE_DEEP_GEMM
    eligible = eligible and output_dtype == torch.bfloat16
    return (eligible and weight.shape[0] % 128 == 0
            and weight.shape[1] % 128 == 0)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
@@ -114,6 +130,29 @@ def apply_w8a8_block_fp8_linear(
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_deepgemm(output_dtype, weight):
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
q_input, x_scale = per_token_group_quant_fp8(
input_2d,
block_size[1],
column_major_scales=True,
)
output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm(
q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=output_dtype)
if bias is not None:
output += bias
return output.to(dtype=output_dtype).view(*output_shape)
if current_platform.is_cuda():
if current_platform.has_device_capability(100):
@@ -134,7 +173,6 @@ def apply_w8a8_block_fp8_linear(
w8a8_blockscale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported)
if use_cutlass:
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=use_cutlass)