# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import TYPE_CHECKING, Any import torch from torch.nn import Module from torch.utils._python_dispatch import TorchDispatchMode import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod, ) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( init_fp8_linear_kernel, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( apply_fi_trtllm_fp8_per_tensor_moe, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, maybe_post_process_fp8_weight_block, process_fp8_input_tensor_strategy_moe, process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy_moe, validate_fp8_block_shape, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, kFp8Dynamic128Sym, kFp8DynamicTensorSym, kFp8DynamicTokenSym, kFp8Static128BlockSym, kFp8StaticTensorSym, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, ) from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight from vllm.model_executor.parameter import ( BlockQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( is_deep_gemm_supported, ) if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper ACTIVATION_SCHEMES = ["static", "dynamic"] logger = init_logger(__name__) class Fp8Config(QuantizationConfig): """Config class for FP8.""" def __init__( self, is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ignored_layers: list[str] | None = None, weight_block_size: list[int] | None = None, ) -> None: super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if activation_scheme not in ACTIVATION_SCHEMES: raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] if weight_block_size is not None: if not is_checkpoint_fp8_serialized: raise ValueError( "The block-wise quantization only supports fp8-serialized " "checkpoint for now." ) if len(weight_block_size) != 2: raise ValueError( "The quantization block size of weight must have 2 " f"dimensions, but got {len(weight_block_size)} dimensions" ) if activation_scheme != "dynamic": raise ValueError( "The block-wise quantization only supports " "dynamic activation scheme for now, but got " f"{activation_scheme} activation scheme." ) self.weight_block_size = weight_block_size @classmethod def get_name(cls) -> QuantizationMethods: return "fp8" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 75 @classmethod def get_config_filenames(cls) -> list[str]: return [] def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if self.ignored_layers is not None: self.ignored_layers = hf_to_vllm_mapper.apply_list(self.ignored_layers) @classmethod def from_config(cls, config: dict[str, Any]) -> "Fp8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if not ignored_layers: ignored_layers = cls.get_from_keys_or( config, ["modules_to_not_convert"], None ) return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, weight_block_size=weight_block_size, ) def get_xpu_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": raise NotImplementedError( "FP8 quantization is not supported during xpu kernel migration." ) def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": if current_platform.is_xpu(): return self.get_xpu_quant_method(layer, prefix) if isinstance(layer, LinearBase): if is_layer_skipped( prefix=prefix, ignored_layers=self.ignored_layers, fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() if not self.is_checkpoint_fp8_serialized: online_method = Fp8OnlineLinearMethod(self) online_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return online_method else: offline_method = Fp8LinearMethod(self) offline_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return offline_method elif isinstance(layer, FusedMoE): if is_layer_skipped( prefix=prefix, ignored_layers=self.ignored_layers, fused_mapping=self.packed_modules_mapping, ): return UnquantizedFusedMoEMethod(layer.moe_config) if self.is_checkpoint_fp8_serialized: moe_quant_method = Fp8MoEMethod(self, layer) else: moe_quant_method = Fp8OnlineMoEMethod(self, layer) return moe_quant_method elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) return None def get_cache_scale(self, name: str) -> str | None: """ Check whether the param name matches the format for k/v cache scales in compressed-tensors. If this is the case, return its equivalent param name expected by vLLM :param name: param name :return: matching param name for KV cache scale in vLLM """ if name.endswith(".output_scale") and ".k_proj" in name: return name.replace(".k_proj.output_scale", ".attn.k_scale") if name.endswith(".output_scale") and ".v_proj" in name: return name.replace(".v_proj.output_scale", ".attn.v_scale") if name.endswith(".output_scale") and ".q_proj" in name: return name.replace(".q_proj.output_scale", ".attn.q_scale") if name.endswith("self_attn.prob_output_scale"): return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None class CopyNumelCounter(TorchDispatchMode): """ Tracks total number of elements modified with `copy_`. Useful for keeping track of weight loading where underlying weights can be arbitrarily transformed (such as with `narrow`) before calling copy. """ def __init__(self): super().__init__() self.copied_numel = 0 def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} out = func(*args, **kwargs) if func == torch.ops.aten.copy_.default: self.copied_numel += args[0].numel() return out def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None: """Copies any attrs present in `old` but not in `new` to `new`""" new_attrs = set(dir(new)) attrs_to_set = {} for attr in dir(old): if attr not in new_attrs: attrs_to_set[attr] = getattr(old, attr) set_weight_attrs(new, attrs_to_set) class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. Limitations: 1. Only support float8_e4m3fn data type due to the limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) Args: quant_config: The quantization config. """ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.out_dtype = torch.get_default_dtype() # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.marlin_input_dtype = None self.use_marlin = ( not current_platform.has_device_capability(89) or envs.VLLM_TEST_FORCE_FP8_MARLIN ) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False if vllm_is_batch_invariant(): self.use_marlin = False self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() self.use_deep_gemm = is_deep_gemm_supported() self.weight_block_size = self.quant_config.weight_block_size self.block_quant = self.weight_block_size is not None self.act_q_static = self.quant_config.activation_scheme == "static" if self.block_quant: assert not self.act_q_static assert self.weight_block_size is not None self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(*self.weight_block_size), act_quant_group_shape=GroupShape(1, self.weight_block_size[0]), cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported, ) else: # Use per-token quantization for better perf if dynamic and cutlass if self.act_q_static: activation_quant_key = kFp8StaticTensorSym elif cutlass_fp8_supported(): activation_quant_key = kFp8DynamicTokenSym else: activation_quant_key = kFp8DynamicTensorSym self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=activation_quant_key, weight_quant_key=kFp8StaticTensorSym, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.orig_dtype = params_dtype layer.weight_block_size = None if self.block_quant: assert self.weight_block_size is not None layer.weight_block_size = self.weight_block_size validate_fp8_block_shape( layer, input_size, output_size, input_size_per_partition, output_partition_sizes, self.weight_block_size, ) weight = create_fp8_weight_parameter( output_size_per_partition, input_size_per_partition, weight_loader ) layer.register_parameter("weight", weight) # WEIGHT SCALE if not self.block_quant: scale = create_fp8_scale_parameter( PerTensorScaleParameter, output_partition_sizes, input_size_per_partition, None, weight_loader, ) layer.register_parameter("weight_scale", scale) else: assert not self.act_q_static assert self.weight_block_size is not None scale = create_fp8_scale_parameter( BlockQuantScaleParameter, output_partition_sizes, input_size_per_partition, self.weight_block_size, weight_loader, ) # The weight_scale_inv name is intentional for deepseekv3 layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE if self.act_q_static: scale = create_fp8_input_scale(output_partition_sizes, weight_loader) set_weight_attrs(scale, {"scale_type": "input_scale"}) layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: size_k_first = True input_scale = None # TODO(rob): refactor block quant into separate class. if self.block_quant: assert not self.act_q_static size_k_first = False weight, weight_scale_inv = process_fp8_weight_block_strategy( layer.weight, layer.weight_scale_inv ) # Update layer with new values replace_parameter(layer, "weight", weight.data) replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) # If checkpoint not serialized fp8, quantize the weights. else: # If checkpoint is fp8 per-tensor, handle that there are N scales for N # shards in a fused module weight = layer.weight weight_scale = layer.weight_scale # If using w8a8, torch._scaled_mm needs per tensor, so # requantize the logical shards as a single weight. if not self.use_marlin: weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( weight, weight_scale, layer.logical_widths, getattr(layer, "input_scale", None), ) if self.act_q_static: assert input_scale is not None input_scale = input_scale.max() weight = weight.t() # Update layer with new values. replace_parameter(layer, "weight", weight.data) replace_parameter(layer, "weight_scale", weight_scale.data) if input_scale is not None: replace_parameter(layer, "input_scale", input_scale) else: layer.input_scale = None if self.use_marlin: prepare_fp8_layer_for_marlin( layer, size_k_first, input_dtype=self.marlin_input_dtype ) # Activations not quantized for marlin. del layer.input_scale return if self.block_quant: maybe_post_process_fp8_weight_block(layer) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. if vllm_is_batch_invariant(): if self.block_quant: assert self.weight_block_size is not None return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale_inv, input_scale=layer.input_scale, bias=bias, ) else: # per-tensor/channel: dequant to BF16 and run GEMM weight_fp8 = layer.weight.to(torch.bfloat16) weight_scale = layer.weight_scale.to(torch.bfloat16) if weight_scale.numel() == 1: # Per-tensor: simple scalar multiplication weight_bf16 = weight_fp8 * weight_scale else: # Multiple scales (fused modules like QKV) # Try to infer correct broadcasting # weight is [K, N], scale could be [num_logical_weights] # Need to figure out how to broadcast - for now just try # direct multiplication if ( weight_scale.dim() == 1 and weight_scale.shape[0] == weight_fp8.shape[0] ): # Per-row scaling weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1) else: # Fallback weight_bf16 = weight_fp8 * weight_scale return torch.nn.functional.linear(x, weight_bf16.t(), bias) if self.use_marlin: if self.block_quant: weight_scale = layer.weight_scale_inv else: weight_scale = layer.weight_scale return apply_fp8_marlin_linear( input=x, weight=layer.weight, weight_scale=weight_scale, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, input_dtype=self.marlin_input_dtype, bias=bias, ) if self.block_quant: assert self.weight_block_size is not None return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale_inv, input_scale=layer.input_scale, bias=bias, ) return self.fp8_linear.apply_weights(layer, x, bias) class Fp8OnlineLinearMethod(Fp8LinearMethod): """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint and quantized the weights during loading.""" def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.orig_dtype = params_dtype layer.weight_block_size = None # WEIGHT def patched_weight_loader(param, loaded_weight, *args, **kwargs): # track how many elements we have updated if not hasattr(layer, "_loaded_numel"): layer._loaded_numel = 0 # when the first `loaded_weight` is about to be # loaded to `param`, materialize `param` just-in-time weight = ModelWeightParameter( data=torch.empty_like(layer.weight, device=layer._load_device), input_dim=1, output_dim=0, weight_loader=patched_weight_loader, ) _copy_missing_attrs(layer.weight, weight) layer.register_parameter("weight", weight) del layer._load_device # refresh the reference to `param` to reflect just-in-time # materialization param = layer.weight # load the current weight chunk copy_numel_counter = CopyNumelCounter() with copy_numel_counter: res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] layer._loaded_numel += copy_numel_counter.copied_numel # if we have loaded all of the elements, call # process_weights_after_loading target_loaded_numel = layer.weight.numel() if layer._loaded_numel == target_loaded_numel: self.process_weights_after_loading(layer) # Prevent the usual `process_weights_after_loading` call from doing # anything layer._already_called_process_weights_after_loading = True # Note that we keep `layer._loaded_numel` around just in case # there is logic added to vllm in the future which calls a # weight loader twice - we do not want to re-initialize in # that case. return res weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, # materialized just-in-time in `patched_weight_loader` device="meta", dtype=params_dtype, ), input_dim=1, output_dim=0, weight_loader=patched_weight_loader, ) # stash the correct device for `patched_weight_loader` layer._load_device = torch.get_default_device() layer.register_parameter("weight", weight) def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return # deferred initialization of randomly initialized weights for the # `--load_format dummy` feature if layer.weight.device == torch.device("meta"): weight = ModelWeightParameter( data=torch.empty_like(layer.weight, device=layer._load_device), input_dim=1, output_dim=0, weight_loader=layer.weight.weight_loader, ) _copy_missing_attrs(layer.weight, weight) layer.register_parameter("weight", weight) initialize_single_dummy_weight(layer.weight) # TODO(future): support block_quant in online quant path assert not self.block_quant layer.input_scale = None qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) weight = qweight.t() # Update layer with new values. replace_parameter(layer, "weight", weight.data) replace_parameter(layer, "weight_scale", weight_scale.data) if self.use_marlin: size_k_first = True prepare_fp8_layer_for_marlin( layer, size_k_first, input_dtype=self.marlin_input_dtype ) # Activations not quantized for marlin. class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after the model weights are loaded. Args: quant_config: The quantization config. """ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): super().__init__(layer.moe_config) self.quant_config = quant_config self.weight_block_size = self.quant_config.weight_block_size self.block_quant: bool = self.weight_block_size is not None self.weight_scale_name = ( "weight_scale_inv" if self.block_quant else "weight_scale" ) # Set weight key and activation key for kernel compatibility if self.block_quant: weight_key = kFp8Static128BlockSym activation_key = kFp8Dynamic128Sym else: weight_key = kFp8StaticTensorSym activation_key = ( kFp8StaticTensorSym if self.quant_config.activation_scheme == "static" else kFp8DynamicTensorSym ) # Select Fp8 MoE backend self.fp8_backend, self.experts_cls = select_fp8_moe_backend( config=self.moe, weight_key=weight_key, activation_key=activation_key, allow_vllm_cutlass=False, ) def create_weights( self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None assert self.quant_config.is_checkpoint_fp8_serialized params_dtype = torch.float8_e4m3fn if self.block_quant: assert self.weight_block_size is not None layer.weight_block_size = self.weight_block_size tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( self.weight_block_size[0], self.weight_block_size[1], ) # NOTE: To ensure proper alignment of the block-wise quantization # scales, the output_size of the weights for both the gate and up # layers must be divisible by block_n. # Required by column parallel or enabling merged weights if intermediate_size_per_partition % block_n != 0: raise ValueError( f"The output_size of gate's and up's weight = " f"{intermediate_size_per_partition} is not divisible by " f"weight quantization block_n = {block_n}." ) if tp_size > 1 and intermediate_size_per_partition % block_k != 0: # Required by row parallel raise ValueError( f"The input_size of down's weight = " f"{intermediate_size_per_partition} is not divisible by " f"weight quantization block_k = {block_k}." ) # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, dtype=params_dtype, ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( torch.empty( num_experts, hidden_size, intermediate_size_per_partition, dtype=params_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES if not self.block_quant: # For per-tensor quant, the scales are per expert and weight. w13_scale_data = torch.ones(num_experts, 2, dtype=torch.float32) w2_scale_data = torch.ones(num_experts, dtype=torch.float32) else: # For block quant, the scales are per block (typically 128x128). w13_scale_data = torch.ones( num_experts, 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), (hidden_size + block_k - 1) // block_k, dtype=torch.float32, ) w2_scale_data = torch.ones( num_experts, (hidden_size + block_n - 1) // block_n, (intermediate_size_per_partition + block_k - 1) // block_k, dtype=torch.float32, ) w13_weight_scale = torch.nn.Parameter(w13_scale_data, requires_grad=False) w2_weight_scale = torch.nn.Parameter(w2_scale_data, requires_grad=False) # Note: name is weight_scale for tensor, weight_scale_inv for block. layer.register_parameter(f"w13_{self.weight_scale_name}", w13_weight_scale) layer.register_parameter(f"w2_{self.weight_scale_name}", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} if self.block_quant else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": assert not self.block_quant w13_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) w2_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) else: layer.w13_input_scale = None layer.w2_input_scale = None def _setup_kernel( self, layer: FusedMoE, w13: torch.Tensor, w2: torch.Tensor, w13_scale: torch.Tensor, w2_scale: torch.Tensor, w13_input_scale: torch.Tensor | None, w2_input_scale: torch.Tensor | None, ) -> None: # Shuffle weights to runtime format. w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( fp8_backend=self.fp8_backend, layer=layer, w13=w13, w2=w2, w13_scale=w13_scale, w2_scale=w2_scale, w13_input_scale=w13_input_scale, w2_input_scale=w2_input_scale, ) # Replace parameters with updated versions. Note that this helper # function ensures the replacement is compatible with RL weight reloads. replace_parameter(layer, "w13_weight", w13) replace_parameter(layer, "w2_weight", w2) replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) # Setup modular kernel for TP case and naive DP/EP case. # In non-naive DP/EP case, we will create a ModularKernelMethod. # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel # in both cases. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: assert self.experts_cls is not None self.moe_mk, self.use_inplace = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, routing_tables=layer._maybe_init_expert_routing_tables(), shared_experts=layer.shared_experts, ) def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return # Allow for accessing weights and scales in standard way. w13 = layer.w13_weight w2 = layer.w2_weight w13_scale = getattr(layer, f"w13_{self.weight_scale_name}") w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") w13_input_scale = layer.w13_input_scale w2_input_scale = layer.w2_input_scale # MI300x and MI325x use FNUZ format for FP8. Convert if needed. if current_platform.is_fp8_fnuz(): w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( w13, w13_scale, w13_input_scale, ) w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( w2, w2_scale, w2_input_scale, ) # Per tensor kernels require single activation scale. Use the max. if self.quant_config.activation_scheme == "static": assert not self.block_quant assert w13_input_scale is not None and w2_input_scale is not None w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( w13_input_scale, w2_input_scale ) replace_parameter(layer, "w13_input_scale", w13_input_scale) replace_parameter(layer, "w2_input_scale", w2_input_scale) # Per tensor kernels require single weight scale for w13 per expert, but # on disk there is a scale for w1 and w3. Use the max to requantize. if not self.block_quant: shard_size = layer.intermediate_size_per_partition w13, w13_scale = process_fp8_weight_tensor_strategy_moe( w13, w13_scale, shard_size, layer.local_num_experts ) # Shuffle weights to runtime format and setup kernel. self._setup_kernel( layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> FusedMoEPermuteExpertsUnpermute: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: # TRTLLM does not use Modular Kernel. if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: return None w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") a1_scale = layer.w13_input_scale a2_scale = layer.w2_input_scale return make_fp8_moe_quant_config( fp8_backend=self.fp8_backend, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, block_shape=self.weight_block_size, ) @property def supports_eplb(self) -> bool: return True @property def allow_inplace(self) -> bool: return True @property def is_monolithic(self) -> bool: return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM # TODO(rob): convert this to MK. if layer.enable_eplb: raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") assert layer.activation == "silu", ( f"Expected 'silu' activation but got {layer.activation}" ) if self.block_quant: import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 e_score_correction_bias = ( layer.e_score_correction_bias.to(x.dtype) if layer.e_score_correction_bias is not None else None ) routing_method_type = layer.routing_method_type return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( routing_logits=router_logits.to(torch.float32) if routing_method_type == RoutingMethodType.DeepSeekV3 else router_logits, routing_bias=e_score_correction_bias, x=x, w13_weight=layer.w13_weight, w13_weight_scale_inv=layer.w13_weight_scale_inv, w2_weight=layer.w2_weight, w2_weight_scale_inv=layer.w2_weight_scale_inv, global_num_experts=layer.global_num_experts, top_k=layer.top_k, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, intermediate_size=layer.intermediate_size_per_partition, expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, block_shape=self.weight_block_size, routing_method_type=routing_method_type, routed_scaling=layer.routed_scaling_factor, ) else: return apply_fi_trtllm_fp8_per_tensor_moe( layer=layer, hidden_states=x, router_logits=router_logits, routing_bias=layer.e_score_correction_bias, global_num_experts=layer.global_num_experts, top_k=layer.top_k, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) def apply( self, layer: FusedMoE, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.moe_mk is not None assert not self.is_monolithic return self.moe_mk( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, inplace=self.use_inplace, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) class Fp8OnlineMoEMethod(Fp8MoEMethod): """MoE method for online FP8 quantization. Supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after the model weights are loaded. Args: quant_config: The quantization config. """ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): super().__init__(quant_config, layer) assert not quant_config.is_checkpoint_fp8_serialized assert quant_config.activation_scheme == "dynamic" assert quant_config.weight_block_size is None def create_weights( self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None # We are doing online quantization, patch the weight loaded # to call `process_weights_after_loading` in a streaming fashion # as soon as the last weight chunk is loaded. weight_loader = extra_weight_attrs["weight_loader"] # create a new holder to prevent modifying behavior of any other # objects which might depend on the old one new_extra_weight_attrs = extra_weight_attrs def patched_weight_loader(param, loaded_weight, *args, **kwargs): # add a counter to track how many elements we have updated if not hasattr(layer, "_loaded_numel"): layer._loaded_numel = 0 # save the ids of original w13 and w2 so that we can # distinguish which one `param` should map to further # down in this file layer._w13_weight_orig_id = id(layer.w13_weight) layer._w2_weight_orig_id = id(layer.w2_weight) # when the first `loaded_weight` is about to be # loaded to `param`, materialize `param` just-in-time w13_weight = torch.nn.Parameter( torch.empty_like(layer.w13_weight, device=layer._load_device), requires_grad=False, ) set_weight_attrs(w13_weight, extra_weight_attrs) _copy_missing_attrs(layer.w13_weight, w13_weight) layer.register_parameter("w13_weight", w13_weight) w2_weight = torch.nn.Parameter( torch.empty_like(layer.w2_weight, device=layer._load_device), requires_grad=False, ) set_weight_attrs(w2_weight, extra_weight_attrs) _copy_missing_attrs(layer.w2_weight, w2_weight) layer.register_parameter("w2_weight", w2_weight) del layer._load_device # refresh the reference to `param` to reflect just-in-time # materialization if id(param) == layer._w13_weight_orig_id: param = layer.w13_weight elif id(param) == layer._w2_weight_orig_id: param = layer.w2_weight # load the current weight chunk copy_numel_counter = CopyNumelCounter() with copy_numel_counter: res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] layer._loaded_numel += copy_numel_counter.copied_numel # if we have loaded all of the elements, call # process_weights_after_loading target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() if layer._loaded_numel == target_loaded_numel: self.process_weights_after_loading(layer) # Prevent the usual `process_weights_after_loading` call # from doing anything layer._already_called_process_weights_after_loading = True # Note that we keep `layer._loaded_numel`, # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id` # around because if EP is on, weight loaders for non-local # experts will run but not actually copy any elements, and we # need to not re-initialize in that case. return res new_extra_weight_attrs["weight_loader"] = patched_weight_loader extra_weight_attrs = new_extra_weight_attrs # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, # materialized just-in-time in `patched_weight_loader` device="meta", dtype=params_dtype, ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( torch.empty( num_experts, hidden_size, intermediate_size_per_partition, # materialized just-in-time in `patched_weight_loader` device="meta", dtype=params_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # stash the correct device for `patched_weight_loader` layer._load_device = torch.get_default_device() # WEIGHT_SCALES # Allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. w13_weight_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) w2_weight_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) layer.w13_input_scale = None layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: if getattr(layer, "_already_called_process_weights_after_loading", False): return # deferred initialization of randomly initialized weights for the # `--load_format dummy` feature if layer.w13_weight.device == torch.device("meta"): w13_weight = torch.nn.Parameter( torch.empty_like(layer.w13_weight, device=layer._load_device), requires_grad=False, ) set_weight_attrs( w13_weight, {"weight_loader": layer.w13_weight.weight_loader} ) _copy_missing_attrs(layer.w13_weight, w13_weight) layer.register_parameter("w13_weight", w13_weight) initialize_single_dummy_weight(layer.w13_weight) if layer.w2_weight.device == torch.device("meta"): w2_weight = torch.nn.Parameter( torch.empty_like(layer.w2_weight, device=layer._load_device), requires_grad=False, ) set_weight_attrs( w2_weight, {"weight_loader": layer.w2_weight.weight_loader} ) _copy_missing_attrs(layer.w2_weight, w2_weight) layer.register_parameter("w2_weight", w2_weight) initialize_single_dummy_weight(layer.w2_weight) # If checkpoint is fp16, quantize in place. fp8_dtype = current_platform.fp8_dtype() w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) w13_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale for expert in range(layer.local_num_experts): w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant( layer.w13_weight[expert, :, :] ) w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant( layer.w2_weight[expert, :, :] ) # Shuffle weights to runtime format and setup kernel. self._setup_kernel( layer, w13, w2, w13_scale, w2_scale, layer.w13_input_scale, layer.w2_input_scale, ) class Fp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. """ def __init__(self, quant_config: Fp8Config): super().__init__(quant_config)