[XPU] enable fp8 online streaming quantization (#30944)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Author:    Yan Ma
Date:      2025-12-20 21:45:27 +08:00
Committer: GitHub
Commit:    560ae9638c (parent: 1501a4070e)

2 changed files with 29 additions and 107 deletions
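For readers outside the review thread: vLLM's online fp8 path quantizes a bf16/fp16 checkpoint to fp8 at load time, so no pre-quantized weights are required. A minimal sketch of how this path would be exercised on an XPU build (model name, prompt, and sampling settings are illustrative, not taken from this change):

from vllm import LLM, SamplingParams

# Request online ("streaming") FP8 quantization at load time: the original
# bf16/fp16 weights are quantized to fp8 as they are processed, rather than
# being read from an fp8-serialized checkpoint.
llm = LLM(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative MoE model
    quantization="fp8",
)
outputs = llm.generate(["Hello, XPU!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)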

Changed file 1 of 2 (fp8 quantization config and MoE backend selection):

@@ -124,11 +124,13 @@ def get_fp8_moe_backend(
     block_quant: bool,
     moe_parallel_config: FusedMoEParallelConfig,
     with_lora_support: bool,
-) -> Fp8MoeBackend:
+) -> Fp8MoeBackend | None:
     """
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """
+    if current_platform.is_xpu():
+        return None
     if with_lora_support:
         return Fp8MoeBackend.TRITON
     # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
@@ -292,6 +294,13 @@ class Fp8Config(QuantizationConfig):
                     return UnquantizedLinearMethod()
                 return XPUFp8LinearMethod(fp8_config)
             elif isinstance(layer, FusedMoE):
+                if is_layer_skipped(
+                    prefix=prefix,
+                    ignored_layers=self.ignored_layers,
+                    fused_mapping=self.packed_modules_mapping,
+                ):
+                    return UnquantizedFusedMoEMethod(layer.moe_config)
                 return XPUFp8MoEMethod(fp8_config, layer)
             elif isinstance(layer, Attention):
                 return Fp8KVCacheMethod(self)
@@ -1107,7 +1116,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            self.rocm_aiter_moe_enabled
+            current_platform.is_xpu()
+            or self.rocm_aiter_moe_enabled
             or self.use_marlin
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):

Changed file 2 of 2 (XPU/IPEX fp8 linear and MoE methods):

@@ -6,13 +6,8 @@ from typing import Any, Optional
 import torch
 from packaging import version
 from torch.nn import Module
-from torch.nn.parameter import Parameter

 from vllm._ipex_ops import ipex_ops as ops
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -24,14 +19,14 @@ from vllm.model_executor.layers.quantization import (
     QuantizationMethods,
 )
 from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8LinearMethod,
+    Fp8OnlineMoEMethod,
+)
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    maybe_create_device_identity,
-)
-from vllm.model_executor.parameter import ModelWeightParameter
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import replace_parameter
 from vllm.platforms import current_platform

 MIN_IPEX_VERSION = "2.6.0"
@@ -309,44 +304,15 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
     def __init__(self, quant_config: Fp8Config):
         super().__init__(quant_config)

-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: list[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        maybe_create_device_identity()
-        output_size_per_partition = sum(output_partition_sizes)
-        weight_loader = extra_weight_attrs.get("weight_loader")
-        layer.logical_widths = output_partition_sizes
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-
-        weight = ModelWeightParameter(
-            data=torch.empty(
-                output_size_per_partition,
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            input_dim=1,
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight", weight)
-
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
             # Update the layer with the new values.
-            layer.weight = Parameter(qweight, requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            replace_parameter(layer, "weight", qweight.data)
+            replace_parameter(layer, "weight_scale", weight_scale.data)
             layer.input_scale = None

     def apply(
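The streaming path above hinges on ops.scaled_fp8_quant(layer.weight, scale=None), which quantizes the loaded bf16/fp16 weight with one dynamically computed per-tensor scale. As a rough reference for what that call produces (a plain-PyTorch sketch, not the IPEX kernel; the function name is illustrative):

import torch

def scaled_fp8_quant_ref(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One dynamic scale per tensor: map the largest |w| onto the fp8 e4m3 range.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().to(torch.float32) / finfo.max
    qweight = (weight.to(torch.float32) / scale).clamp(finfo.min, finfo.max)
    # Dequantization later is qweight.float() * scale (plus any activation scale).
    return qweight.to(torch.float8_e4m3fn), scale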
@@ -363,69 +329,14 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
         return output


-class XPUFp8MoEMethod(FusedMoEMethodBase):
+class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
     def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
-        super().__init__(layer.moe_config)
+        super().__init__(quant_config, layer)
         self.quant_config = quant_config

-    def create_weights(
-        self,
-        layer: Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        layer.intermediate_size_per_partition = intermediate_size_per_partition
-        layer.hidden_size = hidden_size
-        layer.num_experts = num_experts
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-
-        # INPUT_SCALES
-        layer.w13_input_scale = None
-        layer.w2_input_scale = None
-
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
         if not self.quant_config.is_checkpoint_fp8_serialized:
             fp8_dtype = current_platform.fp8_dtype()
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -448,8 +359,9 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w2_weight", w2_weight)

         import intel_extension_for_pytorch as ipex

         ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
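The MoE counterpart applies the same per-tensor quantization expert by expert, producing one scale per expert for the fused w13 and w2 weights before the IPEX fused-MoE op is built. A self-contained sketch of that convention (assuming weights shaped [num_experts, out_dim, in_dim]; illustrative only, not the vLLM/IPEX implementation):

import torch

def quantize_experts_ref(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # w: [num_experts, out_dim, in_dim] in bf16/fp16.
    finfo = torch.finfo(torch.float8_e4m3fn)
    qw = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    scales = torch.empty(w.shape[0], dtype=torch.float32, device=w.device)
    for expert in range(w.shape[0]):
        # One dynamic per-tensor scale per expert.
        scale = w[expert].abs().max().to(torch.float32) / finfo.max
        scales[expert] = scale
        qw[expert] = (
            (w[expert].to(torch.float32) / scale)
            .clamp(finfo.min, finfo.max)
            .to(torch.float8_e4m3fn)
        )
    return qw, scales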