[FP8] Add FP8 WoQ kernel abstraction (#32929)
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
@@ -72,6 +72,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
 from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
     FlashInferFP8ScaledMMLinearKernel,
 )
+from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
+    MarlinFP8ScaledMMLinearKernel,
+)
 from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
     ChannelWiseTorchFP8ScaledMMLinearKernel,
     PerTensorTorchFP8ScaledMMLinearKernel,
@@ -104,6 +107,7 @@ _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]]
 # in priority/performance order (when available)
 _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
     PlatformEnum.CUDA: [
+        MarlinFP8ScaledMMLinearKernel,
         FlashInferFP8ScaledMMLinearKernel,
         CutlassFP8ScaledMMLinearKernel,
         PerTensorTorchFP8ScaledMMLinearKernel,
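A minimal sketch of how a priority-ordered kernel table like `_POSSIBLE_FP8_KERNELS` can be consulted. The `pick_kernel` helper is hypothetical (the PR's actual selection sits behind `init_fp8_linear_kernel`), but the `is_supported`/`can_implement` classmethod protocol matches the kernel classes added in this commit.

def pick_kernel(candidates, config, compute_capability=None):
    # Walk candidates in declared priority order; first usable kernel wins.
    failures = []
    for kernel_cls in candidates:
        ok, reason = kernel_cls.is_supported(compute_capability)
        if ok:
            ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        failures.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("no usable FP8 kernel: " + "; ".join(failures))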
@@ -14,6 +14,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
 from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
     FlashInferFP8ScaledMMLinearKernel,
 )
+from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
+    MarlinFP8ScaledMMLinearKernel,
+)
 from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
     ChannelWiseTorchFP8ScaledMMLinearKernel,
     PerTensorTorchFP8ScaledMMLinearKernel,
@@ -46,6 +49,7 @@ __all__ = [
     "CutlassFP8ScaledMMLinearKernel",
     "CutlassInt8ScaledMMLinearKernel",
     "FlashInferFP8ScaledMMLinearKernel",
+    "MarlinFP8ScaledMMLinearKernel",
     "ChannelWiseTorchFP8ScaledMMLinearKernel",
     "PerTensorTorchFP8ScaledMMLinearKernel",
     "RowWiseTorchFP8ScaledMMLinearKernel",
vllm/model_executor/kernels/linear/scaled_mm/marlin.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+
+import torch
+
+import vllm.envs as envs
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    process_fp8_weight_block_strategy,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    is_fp8_marlin_supported,
+    prepare_fp8_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    kFp8Static128BlockSym,
+)
+from vllm.model_executor.utils import replace_parameter
+from vllm.platforms import current_platform
+
+from .ScaledMMLinearKernel import (
+    FP8ScaledMMLinearKernel,
+    FP8ScaledMMLinearLayerConfig,
+)
+
+
+class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
+    """
+    FP8 Marlin kernel for GPUs that lack FP8 hardware support.
+    Leverages the Marlin kernel for fast weight-only FP8 quantization.
+    """
+
+    @classmethod
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_cuda():
+            return False, "requires CUDA."
+        # Check if platform supports FP8 Marlin
+        if not is_fp8_marlin_supported():
+            return False, "FP8 Marlin requires compute capability 7.5 or higher"
+        if vllm_is_batch_invariant():
+            return False, "FP8 Marlin not supported for batch invariant execution."
+        if (
+            compute_capability is not None
+            and compute_capability >= 89
+            and not envs.VLLM_TEST_FORCE_FP8_MARLIN
+        ):
+            return (
+                False,
+                "To apply FP8 Marlin on high-capability GPUs, please set "
+                "VLLM_TEST_FORCE_FP8_MARLIN=1",
+            )
+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+        return True, None
+
+    def __init__(
+        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
+    ) -> None:
+        super().__init__(c, layer_param_names)
+        self.marlin_input_dtype = None
+        self.block_quant = self.config.weight_quant_key in {kFp8Static128BlockSym}
+        self.size_k_first = not self.block_quant
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if self.block_quant:
+            weight, weight_scale_inv = process_fp8_weight_block_strategy(
+                layer.weight, layer.weight_scale_inv
+            )
+            # Update layer with new values
+            replace_parameter(layer, "weight", weight.data)
+            replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
+        else:
+            weight = layer.weight.t()
+            replace_parameter(layer, "weight", weight.data)
+        layer.input_scale = None
+        prepare_fp8_layer_for_marlin(
+            layer, self.size_k_first, input_dtype=self.marlin_input_dtype
+        )
+        del layer.input_scale
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if self.block_quant:
+            weight_scale = layer.weight_scale_inv
+        else:
+            weight_scale = layer.weight_scale
+        return apply_fp8_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=weight_scale,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            input_dtype=self.marlin_input_dtype,
+            bias=bias,
+        )
+
+    def apply_scaled_mm(
+        self,
+        *,
+        A: torch.Tensor,
+        B: torch.Tensor,
+        out_dtype: torch.dtype,
+        As: torch.Tensor,
+        Bs: torch.Tensor,
+        bias: torch.Tensor | None,
+        output_shape: list,
+    ) -> torch.Tensor:
+        pass
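Illustrative only: how the `is_supported` gate above behaves under this PR's rules. On CUDA with compute capability >= 89 (native FP8 hardware), Marlin is rejected unless VLLM_TEST_FORCE_FP8_MARLIN is set; from 7.5 up to 89 it is accepted as the weight-only fallback. The return values below are assumed from the method body, not observed output.

ok, why = MarlinFP8ScaledMMLinearKernel.is_supported(compute_capability=80)
# -> (True, None) on CUDA, when is_fp8_marlin_supported() holds and
#    batch-invariant mode is off

ok, why = MarlinFP8ScaledMMLinearKernel.is_supported(compute_capability=90)
# -> (False, "To apply FP8 Marlin on high-capability GPUs, please set
#     VLLM_TEST_FORCE_FP8_MARLIN=1"), unless the env var is set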
@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear,
     prepare_fp8_layer_for_marlin,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -177,15 +176,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if self.quant_config.use_marlin:
-            return apply_fp8_marlin_linear(
-                input=x,
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
-                workspace=layer.workspace,
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-                bias=bias,
-            )
-
         return self.fp8_linear.apply_weights(layer, x, bias)
@@ -7,7 +7,6 @@ import torch
 from torch.nn import Module
 from torch.utils._python_dispatch import TorchDispatchMode
 
-import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
@@ -16,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.kernels.linear import (
     init_fp8_linear_kernel,
 )
+from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
@@ -61,10 +61,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     get_marlin_input_dtype,
 )
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
-    apply_fp8_marlin_linear,
-    prepare_fp8_layer_for_marlin,
-)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     is_layer_skipped,
@@ -280,15 +276,6 @@ class Fp8LinearMethod(LinearMethodBase):
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.marlin_input_dtype = None
-        self.use_marlin = (
-            not current_platform.has_device_capability(89)
-            or envs.VLLM_TEST_FORCE_FP8_MARLIN
-        )
-        # Disable marlin for rocm
-        if current_platform.is_rocm() or current_platform.is_xpu():
-            self.use_marlin = False
-        if vllm_is_batch_invariant():
-            self.use_marlin = False
 
         self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
         self.use_deep_gemm = is_deep_gemm_supported()
@@ -297,7 +284,28 @@ class Fp8LinearMethod(LinearMethodBase):
         self.block_quant = self.weight_block_size is not None
         self.act_q_static = self.quant_config.activation_scheme == "static"
 
+        # Use per-token quantization for better perf if dynamic and cutlass
+        if self.act_q_static:
+            activation_quant_key = kFp8StaticTensorSym
+        elif cutlass_fp8_supported():
+            activation_quant_key = kFp8DynamicTokenSym
+        else:
+            activation_quant_key = kFp8DynamicTensorSym
+
         if self.block_quant:
+            weight_quant_key = kFp8Static128BlockSym
+        else:
+            weight_quant_key = kFp8StaticTensorSym
+
+        self.fp8_linear = init_fp8_linear_kernel(
+            activation_quant_key=activation_quant_key,
+            weight_quant_key=weight_quant_key,
+            out_dtype=torch.get_default_dtype(),
+            module_name=self.__class__.__name__,
+        )
+        self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
+
+        if self.block_quant and not self.use_marlin:
             assert not self.act_q_static
             assert self.weight_block_size is not None
             self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
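The quant-key decision above, restated as a standalone sketch for clarity. The function name `select_fp8_quant_keys` is hypothetical, and string stand-ins are used in place of the kFp8* constants from quant_utils.

def select_fp8_quant_keys(act_q_static, cutlass_ok, block_quant):
    # Activation side: a static scale wins; otherwise prefer per-token
    # quantization when cutlass is available (better perf), else per-tensor.
    if act_q_static:
        act_key = "kFp8StaticTensorSym"
    elif cutlass_ok:
        act_key = "kFp8DynamicTokenSym"
    else:
        act_key = "kFp8DynamicTensorSym"
    # Weight side: 128-block scales for block quant, else per-tensor.
    w_key = "kFp8Static128BlockSym" if block_quant else "kFp8StaticTensorSym"
    return act_key, w_key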
@@ -306,21 +314,6 @@ class Fp8LinearMethod(LinearMethodBase):
                 cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                 use_aiter_and_is_supported=self.use_aiter_and_is_supported,
             )
-        else:
-            # Use per-token quantization for better perf if dynamic and cutlass
-            if self.act_q_static:
-                activation_quant_key = kFp8StaticTensorSym
-            elif cutlass_fp8_supported():
-                activation_quant_key = kFp8DynamicTokenSym
-            else:
-                activation_quant_key = kFp8DynamicTensorSym
-
-            self.fp8_linear = init_fp8_linear_kernel(
-                activation_quant_key=activation_quant_key,
-                weight_quant_key=kFp8StaticTensorSym,
-                out_dtype=torch.get_default_dtype(),
-                module_name=self.__class__.__name__,
-            )
 
     def create_weights(
         self,
@@ -387,12 +380,18 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", scale)
 
     def process_weights_after_loading(self, layer: Module) -> None:
-        size_k_first = True
+        if self.use_marlin:
+            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
+            # AttributeError if backend selection changes.
+            if hasattr(self.fp8_linear, "marlin_input_dtype"):
+                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
+            self.fp8_linear.process_weights_after_loading(layer)
+            return
+
         input_scale = None
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
             assert not self.act_q_static
-            size_k_first = False
-
             weight, weight_scale_inv = process_fp8_weight_block_strategy(
                 layer.weight, layer.weight_scale_inv
@@ -411,16 +410,15 @@ class Fp8LinearMethod(LinearMethodBase):
 
             # If using w8a8, torch._scaled_mm needs per tensor, so
             # requantize the logical shards as a single weight.
-            if not self.use_marlin:
-                weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
-                    weight,
-                    weight_scale,
-                    layer.logical_widths,
-                    getattr(layer, "input_scale", None),
-                )
-                if self.act_q_static:
-                    assert input_scale is not None
-                    input_scale = input_scale.max()
+            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
+                weight,
+                weight_scale,
+                layer.logical_widths,
+                getattr(layer, "input_scale", None),
+            )
+            if self.act_q_static:
+                assert input_scale is not None
+                input_scale = input_scale.max()
             weight = weight.t()
 
             # Update layer with new values.
@@ -432,14 +430,6 @@ class Fp8LinearMethod(LinearMethodBase):
         else:
             layer.input_scale = None
 
-        if self.use_marlin:
-            prepare_fp8_layer_for_marlin(
-                layer, size_k_first, input_dtype=self.marlin_input_dtype
-            )
-            # Activations not quantized for marlin.
-            del layer.input_scale
-            return
-
         if self.block_quant:
             maybe_post_process_fp8_weight_block(layer)
@@ -486,21 +476,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return torch.nn.functional.linear(x, weight_bf16.t(), bias)
 
         if self.use_marlin:
-            if self.block_quant:
-                weight_scale = layer.weight_scale_inv
-            else:
-                weight_scale = layer.weight_scale
-
-            return apply_fp8_marlin_linear(
-                input=x,
-                weight=layer.weight,
-                weight_scale=weight_scale,
-                workspace=layer.workspace,
-                size_n=layer.output_size_per_partition,
-                size_k=layer.input_size_per_partition,
-                input_dtype=self.marlin_input_dtype,
-                bias=bias,
-            )
+            return self.fp8_linear.apply_weights(layer, x, bias)
 
         if self.block_quant:
             assert self.weight_block_size is not None
@@ -623,18 +599,20 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
 
         layer.input_scale = None
         qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
-        weight = qweight.t()
-
         # Update layer with new values.
-        replace_parameter(layer, "weight", weight.data)
+        replace_parameter(layer, "weight", qweight.data)
         replace_parameter(layer, "weight_scale", weight_scale.data)
 
         if self.use_marlin:
-            size_k_first = True
-            prepare_fp8_layer_for_marlin(
-                layer, size_k_first, input_dtype=self.marlin_input_dtype
-            )
-            # Activations not quantized for marlin.
+            # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
+            # AttributeError if backend selection changes.
+            if hasattr(self.fp8_linear, "marlin_input_dtype"):
+                self.fp8_linear.marlin_input_dtype = self.marlin_input_dtype
+            self.fp8_linear.process_weights_after_loading(layer)
+        else:
+            weight = qweight.t()
+            replace_parameter(layer, "weight", weight.data)
 
         # Prevent duplicate processing (e.g., during weight reload)
         layer._already_called_process_weights_after_loading = True
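A sketch of the reload-guard pattern the flag above supports. Only the flag assignment is part of this hunk; the early-return check is assumed to live in the caller or at the top of this method.

def process_weights_after_loading(layer):
    if getattr(layer, "_already_called_process_weights_after_loading", False):
        return  # weights were already converted; a reload must not re-run this
    ...  # one-time weight conversion happens here
    layer._already_called_process_weights_after_loading = True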