[Model] [Quantization] Support deepseek_v3 w8a8 fp8 block-wise quantization (#11523)
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: HandH1998 <1335248067@qq.com>
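Note: "w8a8 fp8 block-wise" here means the weight is quantized in (block_n x block_k) tiles, each tile sharing one float32 scale, while activations are quantized dynamically at runtime. A minimal PyTorch sketch of the weight-side layout, assuming an illustrative [128, 128] block size (the actual fused kernels added by this change live in fp8_utils.py; this reference only shows the data layout, not the production numerics):

import torch

def quant_weight_blockwise(w, block_n=128, block_k=128):
    # One fp32 scale per (block_n, block_k) tile of an (N, K) weight; the
    # scale tensor shape matches what this PR registers as weight_scale_inv:
    # (ceil(N / block_n), ceil(K / block_k)).
    n, k = w.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    w_q = torch.empty(n, k, dtype=torch.float8_e4m3fn)
    scale = torch.empty(n_tiles, k_tiles, dtype=torch.float32)
    for i in range(n_tiles):
        for j in range(k_tiles):
            rows = slice(i * block_n, min((i + 1) * block_n, n))
            cols = slice(j * block_k, min((j + 1) * block_k, k))
            tile = w[rows, cols].float()
            s = tile.abs().max().clamp(min=1e-12) / fp8_max
            scale[i, j] = s
            w_q[rows, cols] = (tile / s).to(torch.float8_e4m3fn)
    return w_q, scale

def dequant_weight_blockwise(w_q, scale, block_n=128, block_k=128):
    # Expand each per-tile scale over its tile and multiply to recover fp32.
    n, k = w_q.shape
    s = scale.repeat_interleave(block_n, dim=0)[:n]
    s = s.repeat_interleave(block_k, dim=1)[:, :k]
    return w_q.to(torch.float32) * s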
@@ -6,6 +6,7 @@ from torch.nn.parameter import Parameter
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
@@ -14,6 +15,8 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    apply_w8a8_block_fp8_linear)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -22,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
-from vllm.model_executor.parameter import (ModelWeightParameter,
+from vllm.model_executor.parameter import (BlockQuantScaleParameter,
+                                           ModelWeightParameter,
                                            PerTensorScaleParameter)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
@@ -41,6 +45,7 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
         ignored_layers: Optional[List[str]] = None,
+        weight_block_size: Optional[List[int]] = None,
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
@@ -51,6 +56,20 @@ class Fp8Config(QuantizationConfig):
                 f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
         self.ignored_layers = ignored_layers or []
+        if weight_block_size is not None:
+            if not is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    "The block-wise quantization only supports fp8-serialized "
+                    "checkpoint for now.")
+            if len(weight_block_size) != 2:
+                raise ValueError(
+                    "The quantization block size of weight must have 2 "
+                    f"dimensions, but got {len(weight_block_size)} dimensions")
+            if activation_scheme != "dynamic":
+                raise ValueError("The block-wise quantization only supports "
+                                 "dynamic activation scheme for now, but got "
+                                 f"{activation_scheme} activation scheme.")
+        self.weight_block_size = weight_block_size
 
     @classmethod
     def get_name(cls) -> str:
@@ -74,9 +93,12 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized = ("fp8" in quant_method)
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
         ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"],
+                                                 None)
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
                    activation_scheme=activation_scheme,
-                   ignored_layers=ignored_layers)
+                   ignored_layers=ignored_layers,
+                   weight_block_size=weight_block_size)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
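The new weight_block_size key is read directly from the checkpoint's quantization_config, so no new engine flags are needed. A hypothetical config dict (the concrete values for a given DeepSeek-V3 checkpoint are assumptions here) and how from_config consumes it:

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

# Illustrative quantization_config as it might appear in an fp8-serialized
# checkpoint's config.json; the values are assumptions, not copied from a
# real DeepSeek-V3 checkpoint.
quant_cfg = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],  # [block_n, block_k]
}

cfg = Fp8Config.from_config(quant_cfg)
assert cfg.weight_block_size == [128, 128]
assert cfg.activation_scheme == "dynamic"  # block quant requires dynamic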
@@ -123,6 +145,11 @@ class Fp8LinearMethod(LinearMethodBase):
         if current_platform.is_rocm():
             self.use_marlin = False
 
+        self.block_quant = self.quant_config.weight_block_size is not None
+        if self.block_quant:
+            # Marlin doesn't support block-wise fp8
+            self.use_marlin = False
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -133,10 +160,34 @@ class Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
 
+        if self.block_quant:
+            tp_size = get_tensor_model_parallel_world_size()
+            assert self.quant_config.weight_block_size is not None
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # Required by row parallel
+            if (tp_size > 1
+                    and input_size // input_size_per_partition == tp_size
+                    and input_size_per_partition % block_k != 0):
+                raise ValueError(
+                    f"Weight input_size_per_partition = "
+                    f"{input_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")
+            # Required by column parallel or enabling merged weights
+            if (tp_size > 1 and output_size // output_size_per_partition
+                    == tp_size) or len(output_partition_sizes) > 1:
+                for output_partition_size in output_partition_sizes:
+                    if output_partition_size % block_n != 0:
+                        raise ValueError(
+                            f"Weight output_partition_size = "
+                            f"{output_partition_size} is not divisible by "
+                            f"weight quantization block_n = {block_n}.")
+
         layer.logical_widths = output_partition_sizes
 
         layer.input_size_per_partition = input_size_per_partition
@@ -161,12 +212,29 @@ class Fp8LinearMethod(LinearMethodBase):
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-            scale = PerTensorScaleParameter(data=torch.empty(
-                len(output_partition_sizes), dtype=torch.float32),
-                                            weight_loader=weight_loader)
-
-            scale[:] = torch.finfo(torch.float32).min
-            layer.register_parameter("weight_scale", scale)
+            if not self.block_quant:
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes),
+                                     dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("weight_scale", scale)
+            else:
+                assert self.quant_config.activation_scheme == "dynamic"
+                scale = BlockQuantScaleParameter(
+                    data=torch.empty(
+                        (output_size_per_partition + block_n - 1) // block_n,
+                        (input_size_per_partition + block_k - 1) // block_k,
+                        dtype=torch.float32,
+                    ),
+                    input_dim=1,
+                    output_dim=0,
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                # The weight_scale_inv name is intentional for deepseekv3
+                layer.register_parameter("weight_scale_inv", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
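The block-quantized scale tensor is sized with ceiling division so edge tiles smaller than a full block still get a scale. For example, assuming a [128, 128] block size:

block_n, block_k = 128, 128          # assumed weight_block_size
output_size_per_partition = 3072
input_size_per_partition = 7168

scale_shape = (
    (output_size_per_partition + block_n - 1) // block_n,  # 24 block rows
    (input_size_per_partition + block_k - 1) // block_k,   # 56 block cols
)
print(scale_shape)  # (24, 56): one fp32 scale per 128x128 weight tile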
@@ -180,6 +248,9 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
         layer.weight = torch.nn.Parameter(layer.weight.data,
                                           requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
@@ -266,6 +337,17 @@ class Fp8LinearMethod(LinearMethodBase):
                 size_k=layer.input_size_per_partition,
                 bias=bias)
 
+        if self.block_quant:
+            assert self.quant_config.weight_block_size is not None
+            return apply_w8a8_block_fp8_linear(
+                input=x,
+                weight=layer.weight,
+                block_size=self.quant_config.weight_block_size,
+                weight_scale=layer.weight_scale_inv,
+                input_scale=layer.input_scale,
+                bias=bias,
+            )
+
         return apply_fp8_linear(
             input=x,
             weight=layer.weight,
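apply_w8a8_block_fp8_linear dispatches to the fused kernels added in fp8_utils.py (per-token-group activation quantization plus a block-scaled GEMM). A slow, unfused reference of what the weight side computes, assuming weight_scale_inv holds the per-block dequantization multiplier as in the sketch near the top:

import torch
import torch.nn.functional as F

def block_fp8_linear_reference(x, weight_q, weight_scale_inv, block_size, bias=None):
    # weight_q:         (N, K) fp8 weight
    # weight_scale_inv: (ceil(N/block_n), ceil(K/block_k)) fp32 block scales
    # block_size:       [block_n, block_k]
    block_n, block_k = block_size
    n, k = weight_q.shape
    scale = weight_scale_inv.repeat_interleave(block_n, dim=0)[:n]
    scale = scale.repeat_interleave(block_k, dim=1)[:, :k]
    w = weight_q.to(torch.float32) * scale
    # The fused path also quantizes x per token group on the fly (dynamic
    # activation scheme); here activations are simply upcast for clarity.
    b = None if bias is None else bias.to(torch.float32)
    return F.linear(x.to(torch.float32), w, b).to(x.dtype)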
@@ -291,6 +373,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
+        self.block_quant = self.quant_config.weight_block_size is not None
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size: int, params_dtype: torch.dtype,
@@ -298,6 +381,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
+        if self.block_quant:
+            assert self.quant_config.weight_block_size is not None
+            tp_size = get_tensor_model_parallel_world_size()
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # NOTE: To ensure proper alignment of the block-wise quantization
+            # scales, the output_size of the weights for both the gate and up
+            # layers must be divisible by block_n.
+            # Required by column parallel or enabling merged weights
+            if intermediate_size % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size} is not divisible by "
+                    f"weight quantization block_n = {block_n}.")
+            if (tp_size > 1 and intermediate_size % block_k != 0):
+                # Required by row parallel
+                raise ValueError(f"The input_size of down's weight = "
+                                 f"{intermediate_size} is not divisible by "
+                                 f"weight quantization block_k = {block_k}.")
 
         # WEIGHTS
         w13_weight = torch.nn.Parameter(torch.empty(num_experts,
@@ -317,21 +421,45 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if not self.block_quant:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        else:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+            assert self.quant_config.activation_scheme == "dynamic"
 
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
+             value} if self.block_quant else
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in
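The MoE scale shapes mirror the linear case, with w13 stacking the gate and up projections so its block rows are doubled. With the same assumed [128, 128] block size and illustrative DeepSeek-V3-like sizes:

num_experts, hidden_size, intermediate_size = 256, 7168, 2048
block_n, block_k = 128, 128  # assumed weight_block_size

w13_scale_shape = (
    num_experts,
    2 * ((intermediate_size + block_n - 1) // block_n),  # gate + up: 2 * 16
    (hidden_size + block_k - 1) // block_k,              # 56
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,              # 56
    (intermediate_size + block_k - 1) // block_k,        # 16
)
print(w13_scale_shape, w2_scale_shape)  # (256, 32, 56) (256, 56, 16)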
@@ -364,7 +492,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
@@ -489,17 +619,22 @@
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function)
 
-        return fused_experts(x,
-                             layer.w13_weight,
-                             layer.w2_weight,
-                             topk_weights=topk_weights,
-                             topk_ids=topk_ids,
-                             inplace=True,
-                             use_fp8_w8a8=True,
-                             w1_scale=layer.w13_weight_scale,
-                             w2_scale=layer.w2_weight_scale,
-                             a1_scale=layer.w13_input_scale,
-                             a2_scale=layer.w2_input_scale)
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_fp8_w8a8=True,
+            w1_scale=(layer.w13_weight_scale_inv
+                      if self.block_quant else layer.w13_weight_scale),
+            w2_scale=(layer.w2_weight_scale_inv
+                      if self.block_quant else layer.w2_weight_scale),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
+        )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):