[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-07-11 16:01:13 +08:00
committed by GitHub
parent 762be26a8e
commit 8020e98c9f
8 changed files with 561 additions and 88 deletions

View File

@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Any, Callable, Optional, Union
import torch
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
@@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_skip_modules=llm_int8_skip_modules,
llm_int8_threshold=llm_int8_threshold)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
if isinstance(layer, LinearBase):
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
elif isinstance(layer, FusedMoE):
return BitsAndBytesMoEMethod(self)
return None
@@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
return substr_check or prefix_check
def calculate_quant_ratio(dtype):
if dtype.is_floating_point:
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
else:
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
class BitsAndBytesLinearMethod(LinearMethodBase):
"""Linear method for BitsAndBytes.
@@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
**extra_weight_attrs):
from bitsandbytes.nn import Int8Params
def calculate_quant_ratio(dtype):
if dtype.is_floating_point:
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
else:
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
def create_qweight_for_8bit():
qweight = Int8Params(
data=torch.empty(sum(output_partition_sizes),
@@ -394,3 +401,210 @@ try:
except AttributeError as error:
raise error
class BitsAndBytesMoEMethod(FusedMoEMethodBase):
"""MoE method for BitsAndBytes.
Args:
quant_config: The BitsAndBytes quantization config.
"""
def __init__(self, quant_config: BitsAndBytesConfig):
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.45.3":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.3.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.45.3 via "
"`pip install bitsandbytes>=0.45.3` to use "
"bitsandbytes quantizer.") from err
self.topk_indices_dtype = None
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.load_in_8bit:
call_fun = self._create_weights_8bit
else:
call_fun = self._create_weights_4bit
call_fun(
layer,
num_experts,
hidden_size,
intermediate_size_per_partition,
params_dtype,
**extra_weight_attrs,
)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.quant_config.load_in_8bit:
w13, w2 = self._apply_8bit_dequant(layer)
else:
w13, w2 = self._apply_4bit_dequnt(layer)
return fused_experts(
hidden_states=x,
w1=w13,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
def _create_weights_4bit(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
quant_ratio = calculate_quant_ratio(params_dtype)
# Fused gate_up_proj (column parallel)
w13_total_size = (hidden_size * 2 *
intermediate_size_per_partition) // quant_ratio
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
w13_total_size,
1,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
set_weight_attrs(
w13_qweight,
{
"num_experts":
num_experts,
"input_dim":
hidden_size,
"output_dim":
2 * intermediate_size_per_partition,
"experts_shape": (
num_experts,
intermediate_size_per_partition * 2,
hidden_size,
),
"pack_factor":
quant_ratio,
"use_bitsandbytes_4bit":
True,
},
)
# down_proj (row parallel)
w2_total_size = (hidden_size *
intermediate_size_per_partition) // quant_ratio
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
w2_total_size,
1,
dtype=torch.uint8,
),
requires_grad=False,
)
set_weight_attrs(
w2_qweight,
{
"num_experts":
num_experts,
"input_dim":
intermediate_size_per_partition,
"output_dim":
hidden_size,
"experts_shape": (
num_experts,
hidden_size,
intermediate_size_per_partition,
),
"pack_factor":
quant_ratio,
"use_bitsandbytes_4bit":
True,
},
)
layer.register_parameter("w2_weight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
def _create_weights_8bit(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
raise NotImplementedError
def _apply_4bit_dequnt(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
from bitsandbytes.functional import dequantize_4bit
w13 = dequantize_4bit(
layer.w13_weight.reshape(-1, 1),
layer.w13_weight.bnb_quant_state,
)
w2 = dequantize_4bit(
layer.w2_weight.reshape(-1, 1),
layer.w2_weight.bnb_quant_state,
)
w13 = w13.reshape(layer.w13_weight.experts_shape)
w2 = w2.reshape(layer.w2_weight.experts_shape)
return w13, w2
def _apply_8bit_dequant(
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError