[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
```diff
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm
+from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used

 logger = init_logger(__name__)

```
```diff
@@ -256,6 +257,7 @@ def _per_token_group_quant_fp8(
     # Information for float8
     fp8_min,
     fp8_max,
+    use_ue8m0: tl.constexpr,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -285,7 +287,8 @@ def _per_token_group_quant_fp8(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
```
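The new `use_ue8m0` constexpr selects UE8M0 scales, i.e. power-of-two scales (an unsigned 8-bit exponent with no mantissa), which the Blackwell DeepGEMM path gated by `is_blackwell_deep_gemm_used()` expects. A minimal PyTorch sketch of the same rounding step, with a hypothetical helper name:

```python
import torch

def round_scale_ue8m0(scale_raw: torch.Tensor) -> torch.Tensor:
    """Round each scale up to the next power of two, mirroring
    tl.math.exp2(tl.ceil(tl.log2(scale_raw))) in the kernel."""
    return torch.exp2(torch.ceil(torch.log2(scale_raw)))

# log2(0.3) ~= -1.74, ceil -> -1, exp2 -> 0.5; 1.0 is already 2**0.
print(round_scale_ue8m0(torch.tensor([0.3, 1.0, 3.7])))
# tensor([0.5000, 1.0000, 4.0000])
```

Rounding up matters: since the rounded `y_s` is at least `_absmax / fp8_max`, the quantized value `y / y_s` stays within `[fp8_min, fp8_max]`, so the subsequent `tl.clamp` is only a safety net.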
```diff
@@ -309,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor(
     # Information for float8
     fp8_min,
     fp8_max,
+    use_ue8m0: tl.constexpr,
     # Meta-parameters
     BLOCK: tl.constexpr,
 ):
@@ -347,7 +351,8 @@ def _per_token_group_quant_fp8_colmajor(
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-    y_s = _absmax / fp8_max
+    scale_raw = _absmax / fp8_max
+    y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw
     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

     tl.store(y_q_ptr + cols, y_q, mask=mask)
```
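`_per_token_group_quant_fp8_colmajor` receives the identical scale change; the two kernels differ only in how the scale tensor is laid out. An illustrative sketch of what column-major scales mean for the output buffer (the allocation pattern here is an assumption inferred from the existence of the separate colmajor kernel, not taken from this diff):

```python
import torch

M, num_groups = 4, 3  # tokens, groups per token

# Row-major scales: each token's group scales are contiguous.
x_s_row = torch.empty(M, num_groups, dtype=torch.float32)

# Column-major scales: same logical [M, num_groups] view, but the
# storage is [num_groups, M], so a given group's scales for all
# tokens sit adjacent in memory.
x_s_col = torch.empty(num_groups, M, dtype=torch.float32).permute(1, 0)

assert x_s_col.shape == (M, num_groups)
assert not x_s_col.is_contiguous()
assert x_s_col.t().is_contiguous()
```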
```diff
@@ -373,9 +378,11 @@ def per_token_group_quant_fp8(
         is supported for now.
         column_major_scales: Outputs scales in column major.
+        out_q: Optional output tensor. If not provided, function will create.
     Returns:
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
-        scaling factor for quantization.
+        scaling factor.
     """
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
```
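A hedged usage sketch of the public function this docstring describes (the module path is an assumption inferred from the `w8a8_utils` import in the first hunk; shapes follow the docstring):

```python
import torch
# Assumed location of per_token_group_quant_fp8; the first hunk's import
# of w8a8_utils is consistent with fp8_utils.py in the same package.
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

# Triton kernels require a CUDA device.
x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)

assert x_q.shape == x.shape            # FP8 payload, same shape as input
assert x_s.shape == (16, 512 // 128)   # one scale per 128-element group
```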
```diff
@@ -418,6 +425,7 @@ def per_token_group_quant_fp8(
             eps,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
+            use_ue8m0=is_blackwell_deep_gemm_used(),
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
@@ -433,6 +441,7 @@ def per_token_group_quant_fp8(
             eps,
             fp8_min=fp8_min,
             fp8_max=fp8_max,
+            use_ue8m0=is_blackwell_deep_gemm_used(),
             BLOCK=BLOCK,
             num_warps=num_warps,
             num_stages=num_stages,
```
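Per the title, this quantization path previously ran through Torch ops. As a semantic reference for what both Triton kernels compute, here is a plain PyTorch sketch (an equivalent formulation under the defaults shown above, not the removed code):

```python
import torch

def per_token_group_quant_fp8_ref(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    use_ue8m0: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference per-token-group FP8 quantization (row-major scales)."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_min, fp8_max = finfo.min, finfo.max
    # One scale per contiguous group of `group_size` elements in the last dim.
    xg = x.float().reshape(*x.shape[:-1], -1, group_size)
    absmax = xg.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = absmax / fp8_max
    if use_ue8m0:
        # Same power-of-two rounding as the use_ue8m0 branch in the kernels.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    xq = (xg / scale).clamp(fp8_min, fp8_max).to(torch.float8_e4m3fn)
    return xq.reshape(x.shape), scale.squeeze(-1)
```

The Triton kernels fuse these steps into a single pass over each group, avoiding the intermediate tensors this reference materializes.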