[XPU][Feature] fp8 online quantization support for XPU (#23148)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Co-authored-by: Qiming Zhang <qiming1.zhang@intel.com>
This commit is contained in:
Yan Ma
2025-09-02 12:06:53 +08:00
committed by GitHub
parent 1fa1d6a9a0
commit 7be0cb8e9e
4 changed files with 242 additions and 2 deletions

View File

@@ -137,10 +137,35 @@ class Fp8Config(QuantizationConfig):
ignored_layers=ignored_layers,
weight_block_size=weight_block_size)
def get_xpu_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention
from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod, XPUFp8MoEMethod)
fp8_config = Fp8Config(
is_checkpoint_fp8_serialized=self.is_checkpoint_fp8_serialized,
activation_scheme=self.activation_scheme,
ignored_layers=self.ignored_layers,
weight_block_size=self.weight_block_size)
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignored_layers,
fused_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
return XPUFp8LinearMethod(fp8_config)
elif isinstance(layer, FusedMoE):
return XPUFp8MoEMethod(fp8_config, layer)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix=prefix,
ignored_layers=self.ignored_layers,

View File

@@ -1,11 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any, Callable, Optional
import torch
from packaging import version
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm._ipex_ops import ipex_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -13,7 +18,10 @@ from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
Fp8LinearMethod)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
MIN_IPEX_VERSION = "2.6.0"
@@ -251,3 +259,152 @@ class IPEXAWQLinearMethod(AWQLinearMethod):
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
class XPUFp8LinearMethod(Fp8LinearMethod):
def __init__(self, quant_config: Fp8Config):
super().__init__(quant_config)
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
# Update the layer with the new values.
layer.weight = Parameter(qweight, requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight.data
weight_scale = layer.weight_scale.data
output = torch.ops.torch_ipex.fp8_gemm_w8a16(x, weight, True,
weight_scale, bias)
return output
class XPUFp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(layer.moe_config)
self.quant_config = quant_config
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
layer.intermediate_size_per_partition = intermediate_size_per_partition
layer.hidden_size = hidden_size
layer.num_experts = num_experts
layer.orig_dtype = params_dtype
layer.weight_block_size = None
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# INPUT_SCALES
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if not self.quant_config.is_checkpoint_fp8_serialized:
fp8_dtype = current_platform.fp8_dtype()
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
layer.local_num_experts,
dtype=torch.float32,
device=w13_weight.device),
requires_grad=False)
for expert in range(layer.local_num_experts):
w13_weight[expert, :, :], layer.w13_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], layer.w2_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w2_weight.data[expert, :, :])
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
import intel_extension_for_pytorch as ipex
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
layer.w13_weight,
layer.w2_weight,
w1_scale_inv=layer.w13_weight_scale,
w2_scale_inv=layer.w2_weight_scale,
a1_scale_inv=layer.w13_input_scale,
a2_scale_inv=layer.w2_input_scale,
use_prepack=True,
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return layer.ipex_fusion(
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
custom_routing_function=custom_routing_function,
)