[Feature] Add VLLM_USE_DEEP_GEMM_E8M0 Env to Control E8M0 Scale (#21968)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
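For context, this change replaces the unconditional Blackwell check `is_blackwell_deep_gemm_used()` with `is_blackwell_deep_gemm_e8m0_used()`, which additionally honors the new `VLLM_USE_DEEP_GEMM_E8M0` environment variable. A minimal sketch of such a gate, assuming the flag defaults to enabled and Blackwell is detected via compute capability 10.x (names and the default here are illustrative, not vLLM's actual implementation):

import os

import torch


def is_blackwell_deep_gemm_e8m0_used_sketch() -> bool:
    """Illustrative gate: use E8M0 scales only when the env var allows it
    and the current device is Blackwell (compute capability 10.x)."""
    if os.environ.get("VLLM_USE_DEEP_GEMM_E8M0", "1") != "1":
        return False
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10

With a gate like this, setting VLLM_USE_DEEP_GEMM_E8M0=0 in the environment would keep DeepGEMM enabled while skipping the UE8M0 requantization path.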
@@ -45,7 +45,8 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used,
+                                  is_deep_gemm_supported)
 from vllm.utils.flashinfer import has_flashinfer_moe
 
 if TYPE_CHECKING:
@@ -415,10 +416,10 @@ class Fp8LinearMethod(LinearMethodBase):
         # Activations not quantized for marlin.
         del layer.input_scale
 
-        # On B200, DeepGemm only support E8M0 scale, which means we need to
+        # On B200, if E8M0 for DeepGemm is used, we need to
         # requantize the weight and input to the specific scale
         # at the same time.
-        if is_blackwell_deep_gemm_used():
+        if is_blackwell_deep_gemm_e8m0_used():
             assert layer.weight_block_size is not None
             block_sz = tuple(layer.weight_block_size)
             requant_weight_ue8m0_inplace(
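For readers unfamiliar with UE8M0: it is an exponent-only 8-bit scale format, so every scale must be a power of two. The sketch below shows the basic idea behind a requantization like `requant_weight_ue8m0_inplace` — round each block scale up to the next power of two, then rescale the FP8 blocks so dequantized values are unchanged. The shapes and the out-of-place style are assumptions for illustration, not vLLM's actual kernel.

import torch


def requant_ue8m0_sketch(w_q: torch.Tensor, s_inv: torch.Tensor,
                         block: tuple[int, int]):
    """Round (M//bm, N//bn) block scales up to powers of two and rescale
    the (M, N) FP8 weight so that w_q * s_inv is preserved."""
    bm, bn = block
    # UE8M0 stores only an exponent, so snap each scale to 2^ceil(log2(s)).
    s_new = torch.exp2(torch.ceil(torch.log2(s_inv)))
    # Compensate in the quantized weights so dequantization is unchanged:
    # (w_q * s_inv / s_new) * s_new == w_q * s_inv.
    ratio = (s_inv / s_new).repeat_interleave(bm, 0).repeat_interleave(bn, 1)
    w_new = (w_q.float() * ratio).to(w_q.dtype)
    return w_new, s_new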
@@ -505,15 +506,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         elif not self.block_quant:
             logger.warning_once("Model is not block quantized. Not using "
                                 "DeepGemm kernels")
-        elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(90)):
+        elif (is_deep_gemm_supported()):
             logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
             self.allow_deep_gemm = True
-        elif (current_platform.is_cuda()
-              and is_blackwell_deep_gemm_used()):
-            logger.info_once("Using DeepGemm SM100 kernels for "
-                             "Fp8MoEMethod.")
-            self.allow_deep_gemm = True
         else:
             logger.warning_once(
                 "DeepGemm not supported on the current platform.")
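The two hardware-specific branches (the SM90 check and the Blackwell special case) collapse into a single `is_deep_gemm_supported()` check. A plausible consolidation, sketched under the assumption that DeepGEMM targets Hopper (SM90) and Blackwell (SM100); the real helper lives in `vllm.utils.deep_gemm` and may differ:

from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm


def is_deep_gemm_supported_sketch() -> bool:
    """Illustrative check: DeepGEMM is installed and the GPU is a
    supported CUDA architecture (Hopper SM90 or Blackwell SM100)."""
    return (has_deep_gemm() and current_platform.is_cuda()
            and (current_platform.is_device_capability(90)
                 or current_platform.is_device_capability(100)))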
@@ -725,7 +720,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # DeepGemm scales need to be transposed and aligned. We try to do
         # it ahead of time for performance reasons.
-        if self.allow_deep_gemm and not is_blackwell_deep_gemm_used():
+        if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used():
             # Lazy import to avoid CUDA initialization problems.
             if _is_col_major(layer.w13_weight_scale_inv):
                 layer.w13_weight_scale_inv = \
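Here `_is_col_major` guards the ahead-of-time transpose of the scale tensors. A rough equivalent of that layout check over the last two dimensions, assuming batched (E, M, N) scale tensors (illustrative, not the exact helper):

import torch


def is_col_major_sketch(t: torch.Tensor) -> bool:
    """A tensor is column-major in its last two dims when moving down a
    column is contiguous (stride 1) and columns sit back to back."""
    return t.stride(-2) == 1 and t.stride(-1) == t.size(-2)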
@@ -851,7 +846,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             del layer.w13_input_scale
             del layer.w2_input_scale
 
-        if is_blackwell_deep_gemm_used():
+        if is_blackwell_deep_gemm_e8m0_used():
             assert layer.weight_block_size is not None
             # Re-quantise the expert weights so their scales are UE8M0.
             block_sz = tuple(layer.weight_block_size)
@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
-from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
 
 logger = init_logger(__name__)
 
@@ -394,10 +394,8 @@ def per_token_group_quant_fp8(
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor.
     """
-    # TODO(wentao): refactor this
-    # use_ue8m0 should be a global flag that could be set by user
     if use_ue8m0 is None:
-        use_ue8m0 = is_blackwell_deep_gemm_used()
+        use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
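The only behavioral change in `per_token_group_quant_fp8` is the default for `use_ue8m0`. To make the flag's effect concrete, here is a sketch of how a per-group scale could be computed with and without UE8M0 rounding; the group layout and the `fp8_max` value (the float8_e4m3fn maximum) are illustrative assumptions, not the function's exact internals:

import torch


def group_scale_sketch(x: torch.Tensor, group_size: int,
                       use_ue8m0: bool,
                       fp8_max: float = 448.0) -> torch.Tensor:
    """Per-group scale = amax / fp8_max; under UE8M0 it is rounded up to a
    power of two so an exponent-only scale format can represent it."""
    amax = x.abs().reshape(-1, group_size).amax(dim=1).clamp(min=1e-10)
    scale = amax / fp8_max
    if use_ue8m0:
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    return scale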