[XPU][Feature] fp8 online quantization support for XPU (#23148)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional
+from typing import Optional, Union

 import torch

 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -349,3 +350,56 @@ class ipex_ops:
     def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                     block_mapping: torch.Tensor) -> None:
         torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore
+
+    @staticmethod
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        scale_ub: Optional[torch.Tensor] = None,
+        use_per_token_if_dynamic: bool = False,
+        output: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantize the input tensor to FP8 and return the quantized tensor
+        together with its scale.
+
+        This function supports both static and dynamic quantization: if a
+        scale is provided, static scaling is used; if it is omitted, the
+        scale is determined dynamically. Currently, the XPU platform only
+        supports dynamic quantization. The output tensor can optionally
+        be padded for downstream kernels that benefit from padding.
+
+        Args:
+            input: The input tensor to be quantized to FP8.
+            scale: Optional scaling factor for the FP8 quantization.
+            scale_ub: Optional upper bound for the scaling factor in the
+                dynamic per-token case.
+            num_token_padding: If specified, pad the first dimension of
+                the output to at least this value.
+            use_per_token_if_dynamic: Whether to do per-tensor or
+                per-token scaling in the dynamic quantization case.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8
+            and the scaling factor.
+        """
+        # This code assumes batch_dim and num_tokens are flattened.
+        assert input.ndim == 2
+        shape: Union[tuple[int, int], torch.Size] = input.shape
+        out_dtype: torch.dtype = current_platform.fp8_dtype()
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        if output is None:
+            output = torch.empty(shape, device=input.device, dtype=out_dtype)
+        else:
+            assert num_token_padding is None, \
+                "padding not supported if output passed in"
+            assert output.dtype == out_dtype
+        assert scale is None, "only dynamic fp8 quantization supported on XPU"
+        assert not use_per_token_if_dynamic, (
+            "per-token dynamic fp8 quantization not supported on XPU")
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)
+
+        return output, scale
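
For illustration, a minimal usage sketch of the new op. It assumes an XPU-enabled PyTorch/IPEX build and the vllm._ipex_ops module path used in the vLLM tree; shapes and dtypes here are arbitrary examples, not from this commit:

    # Minimal usage sketch (assumption: an XPU device is available and
    # intel_extension_for_pytorch registers dynamic_scaled_fp8_quant).
    import torch

    from vllm._ipex_ops import ipex_ops

    # [num_tokens, hidden_size] activations, batch dims already flattened.
    x = torch.randn(128, 4096, device="xpu", dtype=torch.float16)

    # Dynamic per-tensor quantization: no scale is passed in, so the
    # kernel computes a single float32 scale for the whole tensor.
    x_fp8, scale = ipex_ops.scaled_fp8_quant(x)
    print(x_fp8.dtype, scale.shape)  # e.g. torch.float8_e4m3fn, torch.Size([1])

    # Pad the token dimension to at least 256 rows for downstream kernels.
    x_fp8_padded, _ = ipex_ops.scaled_fp8_quant(x, num_token_padding=256)
    assert x_fp8_padded.shape[0] >= 256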
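
The quantization itself happens inside the fused IPEX kernel. As a semantic reference only, here is a plain-PyTorch sketch of per-tensor dynamic FP8 quantization. This is an assumption about what torch.ops.torch_ipex.dynamic_scaled_fp8_quant computes, not its actual implementation:

    import torch

    def dynamic_scaled_fp8_quant_ref(
        x: torch.Tensor,
        fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Reference-only sketch of per-tensor dynamic FP8 quantization.

        An assumed model of the IPEX kernel's semantics; the real op runs
        fused on the XPU device and writes into a preallocated output.
        """
        finfo = torch.finfo(fp8_dtype)
        # One scale for the whole tensor: map the observed amax onto the
        # largest representable FP8 magnitude.
        amax = x.abs().max().clamp(min=1e-12).to(torch.float32)
        scale = amax / finfo.max
        # Quantize: divide by the scale, clamp to the representable
        # range, then cast down to FP8.
        q = (x.to(torch.float32) / scale).clamp(finfo.min, finfo.max)
        return q.to(fp8_dtype), scale.reshape(1)

Dequantization is then just q.to(torch.float32) * scale, which is why the op returns the scale alongside the FP8 tensor.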