[Feature] Add VLLM_USE_DEEP_GEMM_E8M0 Env to Control E8M0 Scale (#21968)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-08-11 12:39:08 -04:00
committed by GitHub
parent 8e13d9fe6d
commit f7dcce7a4a
9 changed files with 65 additions and 39 deletions

View File

@@ -20,7 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used
logger = init_logger(__name__)
@@ -394,10 +394,8 @@ def per_token_group_quant_fp8(
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor.
"""
# TODO(wentao): refactor this
# use_ue8m0 should be a global flag that could be set by user
if use_ue8m0 is None:
use_ue8m0 = is_blackwell_deep_gemm_used()
use_ue8m0 = is_blackwell_deep_gemm_e8m0_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "