[Hardware/NVIDIA/Kernel] Enable nvidia/DeepSeek-R1-FP4 Model (#16362)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -9,6 +9,8 @@ from torch.nn.parameter import Parameter
|
||||
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
||||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -210,25 +212,37 @@ class ModelOptNvFp4Config(QuantizationConfig):
|
||||
"`hf_quant_config.json` file for your model's "
|
||||
"quant configuration.")
|
||||
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
|
||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||
group_size = quant_config["group_size"]
|
||||
exclude_modules = quant_config["exclude_modules"]
|
||||
if not (group_size and kv_cache_quant_algo and exclude_modules):
|
||||
if ("group_size" and "kv_cache_quant_algo"
|
||||
and "exclude_modules") not in quant_config:
|
||||
raise ValueError("NVFP4 quantization requires group size and "
|
||||
"kv_cache_quant_algo specified in "
|
||||
"hf_quant_config.json")
|
||||
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
|
||||
group_size = quant_config["group_size"]
|
||||
exclude_modules = quant_config["exclude_modules"]
|
||||
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
|
||||
exclude_modules, group_size)
|
||||
|
||||
def is_layer_excluded(self, prefix: str, exclude_modules: List) -> bool:
    """Return True if ``prefix`` fully matches any wildcard pattern.

    Each entry of ``exclude_modules`` is treated as a literal name in which
    ``*`` stands for any run of characters (including none). Every other
    character, ``.`` included, only matches itself. A pattern must cover
    the whole of ``prefix`` to count as a match.

    Args:
        prefix: Name to test against the patterns.
        exclude_modules: Iterable of wildcard pattern strings.

    Returns:
        True when at least one pattern matches, otherwise False.
    """
    import re

    for pattern in exclude_modules:
        # Escape every regex metacharacter (not only '.') so that names
        # containing e.g. '+' or '[' are compared literally, then turn the
        # escaped '*' back into the "match anything" wildcard.
        regex_str = re.escape(pattern).replace(r'\*', r'.*')
        if re.fullmatch(regex_str, prefix):
            return True
    return False
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Select the quantization method object for ``layer``.

    Args:
        layer: Module being set up.
        prefix: Name of the layer, checked against ``self.exclude_modules``.

    Returns:
        * ``UnquantizedLinearMethod`` for a ``LinearBase`` layer whose
          prefix is skipped by ``is_layer_skipped`` or matched by
          ``is_layer_excluded``;
        * ``ModelOptNvFp4LinearMethod`` for any other ``LinearBase`` layer;
        * ``ModelOptFp8KVCacheMethod`` for an ``Attention`` layer;
        * ``ModelOptNvFp4FusedMoE`` for a ``FusedMoE`` layer;
        * ``None`` for every other layer type.
    """
    from vllm.attention.layer import Attention  # Avoid circular import
    if isinstance(layer, LinearBase):
        # A single combined check: either exclusion mechanism on its own is
        # enough to leave the layer unquantized.
        if (is_layer_skipped(prefix, self.exclude_modules)
                or self.is_layer_excluded(prefix, self.exclude_modules)):
            return UnquantizedLinearMethod()
        return ModelOptNvFp4LinearMethod(self)
    elif isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        return ModelOptNvFp4FusedMoE(self)
    return None
|
||||
|
||||
|
||||
@@ -409,3 +423,235 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.view(*output_shape)
|
||||
|
||||
|
||||
class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"""
|
||||
MoE Method for FP4 Quantization.
|
||||
Args:
|
||||
quant_config: NVFP4 Quant Config
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: ModelOptNvFp4Config):
    """Store the NVFP4 quantization config on the instance.

    Args:
        quant_config: Config object kept as ``self.quant_config``.
    """
    self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate and register the NVFP4 expert parameters on ``layer``.

    Registers, in this order: ``w13_weight``, ``w2_weight``,
    ``w13_weight_scale``, ``w2_weight_scale``, ``w13_weight_scale_2``,
    ``w2_weight_scale_2``, ``w13_input_scale`` and ``w2_input_scale``.
    All tensors are created with ``torch.empty`` (uninitialized).

    Args:
        layer: Module that receives the parameters and ``quant_config``.
        num_experts: Size of the leading dimension of every tensor.
        hidden_size: Model hidden size.
        intermediate_size_per_partition: Per-partition intermediate size;
            the ``w13`` tensors use twice this many rows.
        params_dtype: Not read by this method.
        **extra_weight_attrs: Only the ``"weight_loader"`` entry is read.

    Raises:
        ValueError: If the config is not flagged as NVFP4-serialized.
    """
    config = self.quant_config
    if not config.is_checkpoint_nvfp4_serialized:
        raise ValueError("NVFP4 quantization was selected, "
                         " dynamic quantization is not supported.")

    layer.quant_config = config
    packed_dtype = torch.uint8
    block_scale_dtype = torch.float8_e4m3fn
    loader = extra_weight_attrs.get("weight_loader")
    w13_rows = 2 * intermediate_size_per_partition

    def _expert_matrix(rows, cols, dtype):
        # One (rows, cols) matrix per expert, stacked along dim 0.
        return ModelWeightParameter(
            data=torch.empty(num_experts, rows, cols, dtype=dtype),
            input_dim=1,
            output_dim=2,
            weight_loader=loader)

    def _float32_scale(*shape):
        # Float32 scale tensor of the given shape.
        return PerTensorScaleParameter(
            data=torch.empty(*shape, dtype=torch.float32),
            weight_loader=loader)

    # GEMM 1 weights: 2 fp4 items are packed into each uint8 of the last
    # dimension, hence hidden_size // 2.
    layer.register_parameter(
        "w13_weight",
        _expert_matrix(w13_rows, hidden_size // 2, packed_dtype))

    # GEMM 2 weights, packed the same way along the last dimension.
    layer.register_parameter(
        "w2_weight",
        _expert_matrix(hidden_size, intermediate_size_per_partition // 2,
                       packed_dtype))

    # Block scales: the last dimension holds one FP8 entry per
    # `group_size` elements of the corresponding weight dimension.
    layer.register_parameter(
        "w13_weight_scale",
        _expert_matrix(w13_rows, hidden_size // config.group_size,
                       block_scale_dtype))
    layer.register_parameter(
        "w2_weight_scale",
        _expert_matrix(hidden_size,
                       intermediate_size_per_partition // config.group_size,
                       block_scale_dtype))

    # NOTE(review): the updated dict is not passed to any parameter created
    # in this method — confirm whether this tag is meant to be attached.
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

    # Second-level weight scales: two entries per expert for w13, one for w2.
    layer.register_parameter("w13_weight_scale_2",
                             _float32_scale(num_experts, 2))
    layer.register_parameter("w2_weight_scale_2",
                             _float32_scale(num_experts))

    # NOTE(review): same as above, the dict is not consumed afterwards.
    extra_weight_attrs.update(
        {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

    # Input (activation) scales, shaped like the second-level weight scales.
    layer.register_parameter("w13_input_scale",
                             _float32_scale(num_experts, 2))
    layer.register_parameter("w2_input_scale", _float32_scale(num_experts))
|
||||
|
||||
def swizzle_blockscale(self, scale: torch.tensor):
|
||||
assert (scale.dtype == torch.float8_e4m3fn)
|
||||
# Pad and blockwise interleave weight_scale
|
||||
scale_ndim = scale.ndim
|
||||
if scale.ndim == 2:
|
||||
scale = scale.unsqueeze(0)
|
||||
assert scale.ndim == 3
|
||||
B, M, K = scale.shape
|
||||
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||
M_padded = round_up_multiple(M, 128)
|
||||
K_padded = round_up_multiple(K, 4)
|
||||
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||
padded_scale[:B, :M, :K] = scale
|
||||
batches, rows, cols = padded_scale.shape
|
||||
assert rows % 128 == 0
|
||||
assert cols % 4 == 0
|
||||
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||
cols // 4, 4)
|
||||
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||
return (swizzled_scale.reshape(M, K)
|
||||
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# GEMM 1
|
||||
|
||||
assert torch.allclose(
|
||||
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
|
||||
"Expected w1_weight_scale_2 to equal w3_weight_scale_2")
|
||||
|
||||
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
|
||||
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
|
||||
requires_grad=False)
|
||||
|
||||
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
|
||||
torch.float32)
|
||||
layer.g1_alphas = Parameter(
|
||||
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w13_blockscale_swizzled = self.swizzle_blockscale(
|
||||
layer.w13_weight_scale)
|
||||
|
||||
layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w13_input_scale_quant = Parameter(
|
||||
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
layer.g2_alphas = Parameter(
|
||||
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
requires_grad=False)
|
||||
|
||||
# This is for quantization, so we need to invert it.
|
||||
layer.w2_input_scale_quant = Parameter(
|
||||
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
assert (layer.w2_weight_scale.shape[2] % 16 == 0), (
|
||||
"Expected weight_scale.dim(1) to be divisible by 16")
|
||||
assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Blockscale must be represented as FP8-E4M3")
|
||||
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
|
||||
|
||||
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
):
    """Route ``x`` through the experts using the CUTLASS FP4 MoE kernel.

    Experts are chosen with ``FusedMoE.select_experts`` and the computation
    is delegated to ``cutlass_moe_fp4``, whose result is cast to ``x.dtype``.

    Args:
        layer: Module holding the FP4 weights and the derived scales
            (``w13_weight``, ``w2_weight``, ``*_blockscale_swizzled``,
            ``g1_alphas``/``g2_alphas``, ``*_input_scale_quant``).
        x: Activations; dimension 0 is passed as ``m`` and dimension 1
            as ``k``.
        router_logits, top_k, renormalize, use_grouped_topk, topk_group,
        num_expert_group, custom_routing_function, scoring_func,
        e_score_correction_bias: Forwarded to ``FusedMoE.select_experts``.
        global_num_experts: Not read by this method.
        expert_map: Must be ``None``.
        apply_router_weight_on_input: Must be ``False``.
        activation: Must be ``"silu"``.

    Returns:
        The ``cutlass_moe_fp4`` output converted to ``x.dtype``.

    Raises:
        AssertionError: If an unsupported option is requested.
    """
    # Reject the options this method does not implement.
    assert activation == "silu", "Only SiLU activation is supported."
    router_weight_msg = ("Router weight on input is not "
                         "supported for ModelOptNvFp4FusedMoE.")
    assert not apply_router_weight_on_input, router_weight_msg
    expert_map_msg = ("Expert Parallelism /expert_map "
                      "is currently not supported for "
                      "ModelOptNvFp4FusedMoE.")
    assert expert_map is None, expert_map_msg

    routing_kwargs = dict(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )
    topk_weights, topk_ids = FusedMoE.select_experts(**routing_kwargs)

    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        cutlass_moe_fp4)

    w13_packed = layer.w13_weight
    w2_packed = layer.w2_weight
    num_tokens, hidden_dim = x.shape[0], x.shape[1]

    # Cutlass moe takes in activations in BF16/Half precision
    # and fp4 quantized weights loaded from the checkpoint
    result = cutlass_moe_fp4(
        a=x,
        w1_fp4=w13_packed,
        w1_blockscale=layer.w13_blockscale_swizzled,
        w1_alphas=layer.g1_alphas,
        w2_fp4=w2_packed,
        w2_blockscale=layer.w2_blockscale_swizzled,
        w2_alphas=layer.g2_alphas,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        m=num_tokens,
        # The last weight dimension is packed two-per-byte, so double it.
        n=w2_packed.shape[2] * 2,
        k=hidden_dim,
        e=w13_packed.shape[0],
        a1_gscale=layer.w13_input_scale_quant,
        a2_gscale=layer.w2_input_scale_quant,
        device=x.device)
    return result.to(x.dtype)
|
||||
|
||||
Reference in New Issue
Block a user