[XPU][Feature] fp8 online quantization support for XPU (#23148)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
This commit is contained in:
Yan Ma
2025-09-02 12:06:53 +08:00
committed by GitHub
parent 1fa1d6a9a0
commit 7be0cb8e9e
4 changed files with 242 additions and 2 deletions

View File

@@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
@@ -349,3 +350,56 @@ class ipex_ops:
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
@staticmethod
def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function is designed for both static and dynamic quantization:
If you provide the scale, it will use static scaling and if you omit
it, the scale will be determined dynamically. Currently, XPU platform
only supports dynamic quantization. The function also allows optional
padding of the output tensors for downstream kernels that will benefit
from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
assert scale is None, "only dynamic fp8 quantization supported on XPU"
assert not use_per_token_if_dynamic, (
"per token dynamic fp8 quantization not supported on XPU")
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops.torch_ipex.dynamic_scaled_fp8_quant(output, input, scale)
return output, scale