Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-04-01 12:07:43 -04:00
committed by GitHub
parent a57a3044aa
commit e59ca942f5
6 changed files with 773 additions and 114 deletions

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def _is_col_major(x: torch.Tensor) -> bool:
assert x.dim() == 3
b, m, n = x.shape
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
# Check for DeepGemm support.
self.allow_deep_gemm = False
if envs.VLLM_USE_DEEP_GEMM:
if not has_deep_gemm:
logger.warning_once("Failed to import DeepGemm kernels.")
elif (current_platform.is_cuda()
and current_platform.has_device_capability(90)):
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
self.allow_deep_gemm = True
else:
logger.warning_once(
"DeepGemm not supported on the current platform.")
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm:
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
if _is_col_major(layer.w2_weight_scale_inv):
layer.w2_weight_scale_inv = \
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
return
# If checkpoint is fp16, quantize in place.
@@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)