[Feature] Add Hopper DeepGEMM E8M0 for DeepSeekV3.1 scale_fmt (#23666)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Wentao Ye
2025-08-27 10:09:08 -04:00
committed by GitHub
parent 513c1fe255
commit 3af47c3cc6
10 changed files with 68 additions and 53 deletions

View File

@@ -48,8 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
is_deep_gemm_supported)
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_moe
if TYPE_CHECKING:
@@ -427,7 +426,7 @@ class Fp8LinearMethod(LinearMethodBase):
# On B200, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
if is_blackwell_deep_gemm_e8m0_used():
if is_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(
@@ -734,7 +733,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# DeepGemm scales need to be transposed and aligned. We try to do
# it ahead of time for performance reasons.
if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
if self.allow_deep_gemm and not is_deep_gemm_e8m0_used():
# Lazy import to avoid CUDA initialization problems.
if _is_col_major(layer.w13_weight_scale_inv):
layer.w13_weight_scale_inv = \
@@ -871,7 +870,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
if is_blackwell_deep_gemm_e8m0_used():
if is_deep_gemm_e8m0_used():
assert layer.weight_block_size is not None
# Re-quantise the expert weights so their scales are UE8M0.
block_sz = tuple(layer.weight_block_size)

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear)
logger = init_logger(__name__)
@@ -385,7 +385,7 @@ def per_token_group_quant_fp8(
scaling factor.
"""
if use_ue8m0 is None:
use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
use_ue8m0 = is_deep_gemm_e8m0_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "