[Core] Refactor QKVCrossParallelLinear implementation to support BNB 4-bit quantization (#14545)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-03-12 11:12:52 +08:00
committed by GitHub
parent 77a318bd01
commit e392d85831
3 changed files with 233 additions and 64 deletions

View File

@@ -2,9 +2,10 @@
import itertools
from abc import abstractmethod
from typing import Optional, Union
from typing import Any, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
return param[shard_id], loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
    """Split fused BitsAndBytes 4-bit shard metadata into two pieces.

    The first shard is separated out, and the remaining shards are
    re-based so their offsets start at zero. For example, given::

        {'bnb_shard_offsets': array([0, 4, 8, 16]),
         'bnb_quant_state': {0: ..., 1: ..., 2: ...}}

    the returned pair is::

        {'bnb_shard_offsets': array([0, 4]),
         'bnb_quant_state': {0: ...}}

    and::

        {'bnb_shard_offsets': array([0, 4, 12]),
         'bnb_quant_state': {0: ..., 1: ...}}
    """
    offsets = bnb_weight_attrs["bnb_shard_offsets"]
    states = bnb_weight_attrs["bnb_quant_state"]
    # Width of the first shard; used to re-base the remaining offsets.
    first_width = offsets[1]
    num_shards = len(offsets) - 1
    left = {
        "bnb_shard_offsets": offsets[:2],
        "bnb_quant_state": {0: states[0]},
    }
    # Remaining quant states are re-keyed to start from 0 again.
    rest_states = {idx - 1: states[idx] for idx in range(1, num_shards)}
    right = {
        "bnb_shard_offsets": offsets[1:] - first_width,
        "bnb_quant_state": rest_states,
    }
    return left, right
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@@ -1229,7 +1267,24 @@ class RowParallelLinear(LinearBase):
return s
class QKVCrossParallelLinear(torch.nn.Module):
class QKVCrossParallelLinear(LinearBase):
"""Linear layers for efficient cross-attention's QKV transformation.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(self,
hidden_size: int,
@@ -1241,12 +1296,28 @@ class QKVCrossParallelLinear(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
# input_size and output_size are not used, just for alignment
input_size = hidden_size
output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
super().__init__(input_size=input_size,
output_size=output_size,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
self.quant_config = quant_config
# Empty placeholders for loading as a single module.
self.weight = torch.nn.Parameter()
set_weight_attrs(self.weight, {
"weight_loader": self.weight_loader_weight,
})
placeholder_size = 0
assert self.quant_method is not None
self.quant_method.create_weights(self,
placeholder_size, [placeholder_size],
placeholder_size,
placeholder_size,
self.params_dtype,
weight_loader=self.weight_loader)
# Use a dictionary to avoid submodules parameters auto-registration:
# drop-in replacement for a `QKVParallelLinear` module.
self.proj = dict()
@@ -1276,18 +1347,94 @@ class QKVCrossParallelLinear(torch.nn.Module):
if bias:
self.bias = torch.nn.Parameter()
set_weight_attrs(self.bias, {
"weight_loader": self.weight_loader_bias,
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.bias = None
@property
def q_proj_decoder(self):
return self.proj["q_proj_decoder"]
def q_proj_decoder(self) -> ColumnParallelLinear:
layer = self.proj["q_proj_decoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
return layer
@property
def kv_proj_encoder(self):
return self.proj["kv_proj_encoder"]
def kv_proj_encoder(self) -> QKVParallelLinear:
layer = self.proj["kv_proj_encoder"]
for name, param in self.named_parameters():
target_param = getattr(layer, name)
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
return layer
def forward(self, decoder_hidden_states, encoder_hidden_states):
def sync_weight_attrs(
    self,
    src_param: nn.Parameter,
    tgt_param: nn.Parameter,
    mode: Literal["q_proj_decoder", "kv_proj_encoder"],
):
    """Copy weight attributes from the placeholder param to a proj param.

    Only attributes present on ``src_param`` but absent on ``tgt_param``
    are transferred. For bitsandbytes 4-bit params the fused shard
    metadata is first split so each proj layer gets only its own shards.
    """
    src_only_keys = set(src_param.__dict__.keys()) - set(
        tgt_param.__dict__.keys())
    missing_attrs = {key: getattr(src_param, key) for key in src_only_keys}
    # TODO(Isotr0py): handle bitsandbytes 8bit
    is_bnb_4bit = getattr(src_param, "use_bitsandbytes_4bit", False)
    if missing_attrs and is_bnb_4bit:
        q_attrs, kv_attrs = left_shift_bitsandbytes_4bit_shard(missing_attrs)
        if mode == "q_proj_decoder":
            set_weight_attrs(tgt_param, q_attrs)
        elif mode == "kv_proj_encoder":
            set_weight_attrs(tgt_param, kv_attrs)
    else:
        set_weight_attrs(tgt_param, missing_attrs)
def _is_same_param(
self,
src_param: torch.nn.Parameter,
map_param: torch.nn.Parameter,
) -> bool:
"""Check if two parameters are exactly pointing to same things."""
# ignore weight_loader because it's always different
key_to_ignore = ["weight_loader", "_weight_loader"]
has_same_type_name = type(src_param) is type(map_param)
src_param_attrs = {
k: v
for k, v in src_param.__dict__.items() if k not in key_to_ignore
}
map_param_attrs = {
k: v
for k, v in map_param.__dict__.items() if k not in key_to_ignore
}
has_same_attrs = src_param_attrs == map_param_attrs
return has_same_type_name and has_same_attrs
def select_proj_params(
    self,
    layer: nn.Module,
    param: nn.Parameter,
) -> nn.Parameter:
    """
    Given the placeholder param,
    return the corresponding param in the proj layers.
    """
    matches = [
        candidate for _, candidate in layer.named_parameters()
        if self._is_same_param(param, candidate)
    ]
    # Exactly one parameter in the proj layer must mirror the placeholder.
    assert len(matches) == 1
    return matches[0]
def forward( # type: ignore[override]
self,
decoder_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
q, _ = self.q_proj_decoder(decoder_hidden_states)
if encoder_hidden_states is None:
# Encoder KV already cached.
@@ -1300,25 +1447,21 @@ class QKVCrossParallelLinear(torch.nn.Module):
k, v = kv_enc.split(self.kv_size, dim=-1)
return q, k, v
def weight_loader_weight(self,
                         param: torch.nn.Parameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
    # NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
    if loaded_shard_id == "q":
        target = self.q_proj_decoder.weight
        target.weight_loader(target, loaded_weight)
    else:
        target = self.kv_proj_encoder.weight
        target.weight_loader(target, loaded_weight, loaded_shard_id)
def weight_loader(self,
                  param: torch.nn.Parameter,
                  loaded_weight: torch.Tensor,
                  loaded_shard_id: Optional[str] = None):
    """Dispatch a fused checkpoint weight to the matching proj layer."""
    if loaded_shard_id == "q":
        layer = self.q_proj_decoder
        # ColumnParallelLinear's loader takes no shard id.
        extra_args = ()
    else:
        layer = self.kv_proj_encoder
        extra_args = (loaded_shard_id, )
    target_param = self.select_proj_params(layer, param)
    layer.weight_loader(target_param, loaded_weight, *extra_args)
def weight_loader_bias(self,
                       param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor,
                       loaded_shard_id: Optional[str] = None):
    # Delegate to the real proj layer's bias loader, ignoring the
    # placeholder param on this module.
    if loaded_shard_id == "q":
        target = self.q_proj_decoder.bias
        target.weight_loader(target, loaded_weight)
    else:
        target = self.kv_proj_encoder.bias
        target.weight_loader(target, loaded_weight, loaded_shard_id)
def extra_repr(self) -> str:
    """Summarize the layer configuration for ``repr``/``print``."""
    parts = [
        f"in_features={self.input_size}",
        f"q_size={self.q_proj_decoder.output_size_per_partition}",
        f"kv_size={self.kv_size}",
        f"bias={self.bias is not None}",
        f"tp_size={get_tensor_model_parallel_world_size()}",
        "gather_output=False",
    ]
    return ", ".join(parts)