[Perf] Create TMA-aligned input scale tensor for DeepGemm on Hopper (#32619)

Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
group_broadcast,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import _ceil_to_ue8m0, is_deep_gemm_e8m0_used
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -170,6 +171,8 @@ def native_per_token_group_quant_fp8(
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
if is_deep_gemm_e8m0_used():
|
||||
x_s = _ceil_to_ue8m0(x_s)
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
|
||||
|
||||
Reference in New Issue
Block a user