[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2025-07-12 22:38:45 -04:00 (committed by GitHub)
Parent: f45a332886
Commit: 42d440c22b
6 changed files with 26 additions and 42 deletions


@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
logger = init_logger(__name__)
@@ -256,6 +257,7 @@ def _per_token_group_quant_fp8(
# Information for float8
fp8_min,
fp8_max,
+use_ue8m0: tl.constexpr,
# Meta-parameters
BLOCK: tl.constexpr,
):
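Note on the new parameter: use_ue8m0 is declared tl.constexpr, so Triton treats it as a compile-time constant and specializes the kernel per value; the branch it guards below is resolved during compilation, not per element. A standalone toy kernel showing the same pattern (illustrative only, not vLLM code):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _toy_kernel(x_ptr, out_ptr, N, USE_EXP2: tl.constexpr, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(x_ptr + offs, mask=mask, other=1.0)
        # The untaken branch is eliminated at compile time; no runtime cost.
        y = tl.math.exp2(tl.ceil(tl.log2(x))) if USE_EXP2 else x
        tl.store(out_ptr + offs, y, mask=mask)

    x = torch.rand(1024, device="cuda") + 0.5
    out = torch.empty_like(x)
    # One compiled variant is cached per constexpr value.
    _toy_kernel[(triton.cdiv(1024, 256),)](x, out, 1024, USE_EXP2=True, BLOCK=256)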
@@ -285,7 +287,8 @@ def _per_token_group_quant_fp8(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-y_s = _absmax / fp8_max
+scale_raw = _absmax / fp8_max
+y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
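The core change: instead of emitting the raw scale absmax / fp8_max, the UE8M0 path rounds it up to the nearest power of two via exp2(ceil(log2(scale))), which is exactly what an exponent-only 8-bit (UE8M0) scale can represent. Rounding up means |y / y_s| can only shrink, so the clamp to [fp8_min, fp8_max] never clips harder than the raw scale would. The same arithmetic in plain NumPy, for intuition:

    import numpy as np

    def ue8m0_round(scale_raw):
        # Round up to the nearest power of two: exp2(ceil(log2(x))).
        return np.exp2(np.ceil(np.log2(scale_raw)))

    print(ue8m0_round(np.array([0.75, 1.0, 3.2])))  # -> [1. 1. 4.]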
@@ -309,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor(
# Information for float8
fp8_min,
fp8_max,
+use_ue8m0: tl.constexpr,
# Meta-parameters
BLOCK: tl.constexpr,
):
@@ -347,7 +351,8 @@ def _per_token_group_quant_fp8_colmajor(
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-y_s = _absmax / fp8_max
+scale_raw = _absmax / fp8_max
+y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + cols, y_q, mask=mask)
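The column-major kernel receives the identical rounding change; the two kernels differ only in where y_s is written. With column_major_scales, the logical (num_tokens, K // group_size) scale matrix is stored transposed so that the scales for a fixed group index are contiguous. A small stride sketch (illustrative shapes, not vLLM code):

    import torch

    M, K, G = 4, 256, 128                  # tokens, hidden size, group size
    n_groups = K // G
    s_row = torch.empty(M, n_groups)       # row-major scales
    s_col = torch.empty(n_groups, M).t()   # column-major: strides swapped
    print(s_row.stride(), s_col.stride())  # (2, 1) (1, 4)
    # In s_col, the column scale[:, g] for a fixed group g is contiguous.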
@@ -373,9 +378,11 @@ def per_token_group_quant_fp8(
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
-Returns:
-tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
-scaling factor for quantization.
+Returns:
+tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
+scaling factor.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
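A call-site sketch based on this docstring (keyword names come from the hunk; the import path is an assumption about where per_token_group_quant_fp8 lives and may drift):

    import torch
    # Assumed import path for the function patched in this commit.
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)

    x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)
    # The last dimension must be divisible by group_size (asserted above).
    y_q, y_s = per_token_group_quant_fp8(x, group_size=128)
    # y_q: FP8 tensor, same shape as x; y_s: one scale per 128-wide group.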
@@ -418,6 +425,7 @@ def per_token_group_quant_fp8(
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
+use_ue8m0=is_blackwell_deep_gemm_used(),
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
@@ -433,6 +441,7 @@ def per_token_group_quant_fp8(
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
+use_ue8m0=is_blackwell_deep_gemm_used(),
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
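Both launch sites pass use_ue8m0=is_blackwell_deep_gemm_used(), evaluated on the host, so power-of-two scales are produced only when DeepGEMM is in use on Blackwell; elsewhere the kernels compile with the raw-scale branch. A hypothetical sketch of such a gate (the real helper lives in vllm.utils.deep_gemm; its logic is not shown in this diff):

    import torch

    def is_blackwell() -> bool:
        # Hypothetical check: Blackwell parts report compute capability 10.x.
        if not torch.cuda.is_available():
            return False
        major, _ = torch.cuda.get_device_capability()
        return major == 10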