Add option to use DeepGemm contiguous grouped gemm kernel for fused MoE operations. (#13932)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import importlib.util
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -37,6 +38,14 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
|
||||
def _is_col_major(x: torch.Tensor) -> bool:
|
||||
assert x.dim() == 3
|
||||
b, m, n = x.shape
|
||||
return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
@@ -424,6 +433,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.quant_config = quant_config
|
||||
self.block_quant = self.quant_config.weight_block_size is not None
|
||||
|
||||
# Check for DeepGemm support.
|
||||
self.allow_deep_gemm = False
|
||||
if envs.VLLM_USE_DEEP_GEMM:
|
||||
if not has_deep_gemm:
|
||||
logger.warning_once("Failed to import DeepGemm kernels.")
|
||||
elif (current_platform.is_cuda()
|
||||
and current_platform.has_device_capability(90)):
|
||||
logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
|
||||
self.allow_deep_gemm = True
|
||||
else:
|
||||
logger.warning_once(
|
||||
"DeepGemm not supported on the current platform.")
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
@@ -585,6 +607,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
# DeepGemm scales need to be transposed and aligned. We try to do
|
||||
# it ahead of time for performance reasons.
|
||||
if self.allow_deep_gemm:
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
if _is_col_major(layer.w13_weight_scale_inv):
|
||||
layer.w13_weight_scale_inv = \
|
||||
dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous()
|
||||
if _is_col_major(layer.w2_weight_scale_inv):
|
||||
layer.w2_weight_scale_inv = \
|
||||
dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous()
|
||||
|
||||
return
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
@@ -773,6 +808,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user