[Model] Deepseek GGUF support (#13167)

Author: Szymon Ożóg
Date: 2025-02-27 11:08:35 +01:00 (committed by GitHub)
Parent: edf309ebbe
Commit: 7f0be2aa24
8 changed files with 198 additions and 10 deletions
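
This commit teaches vLLM's GGUF quantization backend about DeepSeek-style mixture-of-experts models: FusedMoE layers now resolve to a new GGUFMoEMethod, which registers lazily materialized expert weights and applies the experts with a per-token dequantize-multiply loop. As a rough usage sketch (the GGUF file path and tokenizer repo below are placeholders, not part of this commit):

    from vllm import LLM

    # Placeholder GGUF file and a matching Hugging Face tokenizer; vLLM
    # reads the quantized weights directly from the GGUF file.
    llm = LLM(model="./DeepSeek-V2-Lite.Q4_K_M.gguf",
              tokenizer="deepseek-ai/DeepSeek-V2-Lite")
    outputs = llm.generate(["Hello!"])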


@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional

 import gguf
 import torch
@@ -8,6 +8,9 @@ from gguf import GGMLQuantizationType as WeightType
 from torch.nn.parameter import Parameter, UninitializedParameter

 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -29,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
         return "gguf"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
-        return [torch.half, torch.bfloat16]
+        return [torch.half]

     @classmethod
     def get_min_capability(cls) -> int:
@@ -49,6 +52,8 @@
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
             return GGUFEmbeddingMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return GGUFMoEMethod(self)
         return None
@@ -184,6 +189,124 @@ class GGUFLinearMethod(LinearMethodBase):
         return out


+class GGUFMoEMethod(FusedMoEMethodBase):
+    """MoE method for GGUF.
+
+    Args:
+        quant_config: The GGUF quantization config.
+    """
+
+    def __init__(self, quant_config: GGUFConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+        tensor_shape = (num_experts, 2 * intermediate_size_per_partition,
+                        hidden_size)
+        # gate + up projection, fused into a single w13 tensor
+        w13_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w13_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                # Per-expert GGUF shards are collected here during weight
+                # loading and materialized into the full tensor afterwards.
+                "data_container": [],
+            })
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+        layer.register_parameter("w13_qweight", w13_qweight)
+
+        w13_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                     requires_grad=False)
+        set_weight_attrs(w13_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+        set_weight_attrs(w13_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w13_qweight_type", w13_qweight_type)
+
+        tensor_shape = (num_experts, hidden_size,
+                        intermediate_size_per_partition)
+        # down projection
+        w2_qweight = GGUFUninitializedParameter(requires_grad=False)
+        set_weight_attrs(
+            w2_qweight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "tensor_shape": tensor_shape,
+                "is_gguf_weight": True,
+                "data_container": [],
+            })
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+        layer.register_parameter("w2_qweight", w2_qweight)
+
+        w2_qweight_type = Parameter(torch.empty(1, dtype=torch.uint8),
+                                    requires_grad=False)
+        set_weight_attrs(w2_qweight_type, {
+            "is_gguf_weight_type": True,
+            "weight_type": 0,
+            "ignore_warning": True
+        })
+        set_weight_attrs(w2_qweight_type, extra_weight_attrs)
+        layer.register_parameter("w2_qweight_type", w2_qweight_type)
+        self.act = SiluAndMul()
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ):
+        assert activation == "silu", "Only SiLU activation is supported."
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        final_hidden_states = torch.empty_like(x)
+        for tok, (w, idx) in enumerate(zip(topk_weights, topk_ids)):
+            inp = x[tok].reshape((1, ) + x.shape[1:])
+            current_hidden_state = None
+            # Accumulate the routing-weighted outputs of each selected
+            # expert for this token.
+            for ww, ii in zip(w, idx):
+                expert_up = layer.w13_qweight[ii]
+                out = _fuse_mul_mat(inp, expert_up,
+                                    layer.w13_qweight_type.weight_type)
+                out = self.act(out)
+
+                expert_down = layer.w2_qweight[ii]
+                current_state = _fuse_mul_mat(
+                    out, expert_down,
+                    layer.w2_qweight_type.weight_type).mul_(ww)
+                if current_hidden_state is None:
+                    current_hidden_state = current_state
+                else:
+                    current_hidden_state.add_(current_state)
+            final_hidden_states[tok] = current_hidden_state
+        return final_hidden_states
+
+
 class GGUFEmbeddingMethod(GGUFLinearMethod):
     """Embedding method for GGUF.