[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
@@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig):
|
||||
llm_int8_skip_modules=llm_int8_skip_modules,
|
||||
llm_int8_threshold=llm_int8_threshold)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["LinearMethodBase"]:
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
|
||||
return UnquantizedLinearMethod()
|
||||
return BitsAndBytesLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return BitsAndBytesMoEMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
@@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
|
||||
return substr_check or prefix_check
|
||||
|
||||
|
||||
def calculate_quant_ratio(dtype):
|
||||
if dtype.is_floating_point:
|
||||
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
else:
|
||||
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
|
||||
|
||||
class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
"""Linear method for BitsAndBytes.
|
||||
|
||||
@@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
||||
**extra_weight_attrs):
|
||||
from bitsandbytes.nn import Int8Params
|
||||
|
||||
def calculate_quant_ratio(dtype):
|
||||
if dtype.is_floating_point:
|
||||
return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
else:
|
||||
return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits
|
||||
|
||||
def create_qweight_for_8bit():
|
||||
qweight = Int8Params(
|
||||
data=torch.empty(sum(output_partition_sizes),
|
||||
@@ -394,3 +401,210 @@ try:
|
||||
|
||||
except AttributeError as error:
|
||||
raise error
|
||||
|
||||
|
||||
class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for BitsAndBytes.
|
||||
|
||||
Args:
|
||||
quant_config: The BitsAndBytes quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: BitsAndBytesConfig):
|
||||
try:
|
||||
import bitsandbytes
|
||||
if bitsandbytes.__version__ < "0.45.3":
|
||||
raise ImportError("bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.45.3.")
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install bitsandbytes>=0.45.3 via "
|
||||
"`pip install bitsandbytes>=0.45.3` to use "
|
||||
"bitsandbytes quantizer.") from err
|
||||
self.topk_indices_dtype = None
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
if self.quant_config.load_in_8bit:
|
||||
call_fun = self._create_weights_8bit
|
||||
else:
|
||||
call_fun = self._create_weights_4bit
|
||||
call_fun(
|
||||
layer,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
params_dtype,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `BitsAndBytesMoEMethod` yet.")
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
if self.quant_config.load_in_8bit:
|
||||
w13, w2 = self._apply_8bit_dequant(layer)
|
||||
else:
|
||||
w13, w2 = self._apply_4bit_dequnt(layer)
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=w13,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
def _create_weights_4bit(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
quant_ratio = calculate_quant_ratio(params_dtype)
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_total_size = (hidden_size * 2 *
|
||||
intermediate_size_per_partition) // quant_ratio
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
w13_total_size,
|
||||
1,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
set_weight_attrs(
|
||||
w13_qweight,
|
||||
{
|
||||
"num_experts":
|
||||
num_experts,
|
||||
"input_dim":
|
||||
hidden_size,
|
||||
"output_dim":
|
||||
2 * intermediate_size_per_partition,
|
||||
"experts_shape": (
|
||||
num_experts,
|
||||
intermediate_size_per_partition * 2,
|
||||
hidden_size,
|
||||
),
|
||||
"pack_factor":
|
||||
quant_ratio,
|
||||
"use_bitsandbytes_4bit":
|
||||
True,
|
||||
},
|
||||
)
|
||||
# down_proj (row parallel)
|
||||
w2_total_size = (hidden_size *
|
||||
intermediate_size_per_partition) // quant_ratio
|
||||
w2_qweight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
w2_total_size,
|
||||
1,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w2_qweight,
|
||||
{
|
||||
"num_experts":
|
||||
num_experts,
|
||||
"input_dim":
|
||||
intermediate_size_per_partition,
|
||||
"output_dim":
|
||||
hidden_size,
|
||||
"experts_shape": (
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
),
|
||||
"pack_factor":
|
||||
quant_ratio,
|
||||
"use_bitsandbytes_4bit":
|
||||
True,
|
||||
},
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
def _create_weights_8bit(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def _apply_4bit_dequnt(
|
||||
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from bitsandbytes.functional import dequantize_4bit
|
||||
w13 = dequantize_4bit(
|
||||
layer.w13_weight.reshape(-1, 1),
|
||||
layer.w13_weight.bnb_quant_state,
|
||||
)
|
||||
w2 = dequantize_4bit(
|
||||
layer.w2_weight.reshape(-1, 1),
|
||||
layer.w2_weight.bnb_quant_state,
|
||||
)
|
||||
w13 = w13.reshape(layer.w13_weight.experts_shape)
|
||||
w2 = w2.reshape(layer.w2_weight.experts_shape)
|
||||
return w13, w2
|
||||
|
||||
def _apply_8bit_dequant(
|
||||
self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user