[Core] Refactor QKVCrossParallelLinear implementation to support BNB 4-bit quantization (#14545)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -2,9 +2,10 @@
|
||||
|
||||
import itertools
|
||||
from abc import abstractmethod
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
@@ -84,6 +85,43 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
||||
return param[shard_id], loaded_weight
|
||||
|
||||
|
||||
# TODO(Isotr0py): We might need a more flexible structure to handle
|
||||
# bitsandbytes shard offsets.
|
||||
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
|
||||
"""
|
||||
Separate the BitsAndBytes 4-bit shard.
|
||||
|
||||
For example, given bnb weight attributes as below:
|
||||
{
|
||||
'bnb_shard_offsets': array([0, 4, 8, 16]),
|
||||
'bnb_quant_state': {0: ..., 1: ..., 2: ...},
|
||||
}
|
||||
|
||||
The function will return:
|
||||
{
|
||||
'bnb_shard_offsets': array([0, 4]),
|
||||
'bnb_quant_state': {0: ...},
|
||||
}
|
||||
and
|
||||
{
|
||||
'bnb_shard_offsets': array([0, 4, 12]),
|
||||
'bnb_quant_state': {0: ..., 1: ...},
|
||||
}
|
||||
"""
|
||||
shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
|
||||
offset_l = shard_offsets[:2]
|
||||
offset_r = shard_offsets[1:] - shard_offsets[1]
|
||||
quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
|
||||
quant_state_r = {
|
||||
i - 1: bnb_weight_attrs["bnb_quant_state"][i]
|
||||
for i in range(1,
|
||||
len(shard_offsets) - 1)
|
||||
}
|
||||
left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
|
||||
right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
|
||||
return left, right
|
||||
|
||||
|
||||
class LinearMethodBase(QuantizeMethodBase):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@@ -1229,7 +1267,24 @@ class RowParallelLinear(LinearBase):
|
||||
return s
|
||||
|
||||
|
||||
class QKVCrossParallelLinear(torch.nn.Module):
|
||||
class QKVCrossParallelLinear(LinearBase):
|
||||
"""Linear layers for efficient cross-attention's QKV transformation.
|
||||
|
||||
Args:
|
||||
hidden_size: input hidden state size of the transformer.
|
||||
head_size: size of each attention head.
|
||||
total_num_heads: total number of attention query heads.
|
||||
total_num_kv_heads: total number of attention key/value heads. If
|
||||
None, assume total_num_kv_heads = total_num_heads.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
@@ -1241,12 +1296,28 @@ class QKVCrossParallelLinear(torch.nn.Module):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
# input_size and output_size are not used, just for alignment
|
||||
input_size = hidden_size
|
||||
output_size = (total_num_heads + (total_num_kv_heads or 0)) * head_size
|
||||
super().__init__(input_size=input_size,
|
||||
output_size=output_size,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Empty placeholders for loading as a single module.
|
||||
self.weight = torch.nn.Parameter()
|
||||
set_weight_attrs(self.weight, {
|
||||
"weight_loader": self.weight_loader_weight,
|
||||
})
|
||||
placeholder_size = 0
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
placeholder_size, [placeholder_size],
|
||||
placeholder_size,
|
||||
placeholder_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
# Use a dictionary to avoid submodules parameters auto-registration:
|
||||
# drop-in replacement for a `QKVParallelLinear` module.
|
||||
self.proj = dict()
|
||||
@@ -1276,18 +1347,94 @@ class QKVCrossParallelLinear(torch.nn.Module):
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter()
|
||||
set_weight_attrs(self.bias, {
|
||||
"weight_loader": self.weight_loader_bias,
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
@property
|
||||
def q_proj_decoder(self):
|
||||
return self.proj["q_proj_decoder"]
|
||||
def q_proj_decoder(self) -> ColumnParallelLinear:
|
||||
layer = self.proj["q_proj_decoder"]
|
||||
for name, param in self.named_parameters():
|
||||
target_param = getattr(layer, name)
|
||||
self.sync_weight_attrs(param, target_param, mode="q_proj_decoder")
|
||||
return layer
|
||||
|
||||
@property
|
||||
def kv_proj_encoder(self):
|
||||
return self.proj["kv_proj_encoder"]
|
||||
def kv_proj_encoder(self) -> QKVParallelLinear:
|
||||
layer = self.proj["kv_proj_encoder"]
|
||||
for name, param in self.named_parameters():
|
||||
target_param = getattr(layer, name)
|
||||
self.sync_weight_attrs(param, target_param, mode="kv_proj_encoder")
|
||||
return layer
|
||||
|
||||
def forward(self, decoder_hidden_states, encoder_hidden_states):
|
||||
def sync_weight_attrs(
|
||||
self,
|
||||
src_param: nn.Parameter,
|
||||
tgt_param: nn.Parameter,
|
||||
mode: Literal["q_proj_decoder", "kv_proj_encoder"],
|
||||
):
|
||||
missing_attrs_dict = {
|
||||
k: getattr(src_param, k)
|
||||
for k in (set(src_param.__dict__.keys()) -
|
||||
set(tgt_param.__dict__.keys()))
|
||||
}
|
||||
# TODO(Isotr0py): handle bitsandbytes 8bit
|
||||
use_bitsandbytes_4bit = getattr(src_param, "use_bitsandbytes_4bit",
|
||||
False)
|
||||
if (missing_attrs_dict and use_bitsandbytes_4bit):
|
||||
q_proj_attrs, kv_proj_attrs = left_shift_bitsandbytes_4bit_shard(
|
||||
missing_attrs_dict)
|
||||
if mode == "q_proj_decoder":
|
||||
set_weight_attrs(tgt_param, q_proj_attrs)
|
||||
elif mode == "kv_proj_encoder":
|
||||
set_weight_attrs(tgt_param, kv_proj_attrs)
|
||||
else:
|
||||
set_weight_attrs(tgt_param, missing_attrs_dict)
|
||||
|
||||
def _is_same_param(
|
||||
self,
|
||||
src_param: torch.nn.Parameter,
|
||||
map_param: torch.nn.Parameter,
|
||||
) -> bool:
|
||||
"""Check if two parameters are exactly pointing to same things."""
|
||||
# ignore weight_loader because it's always different
|
||||
key_to_ignore = ["weight_loader", "_weight_loader"]
|
||||
has_same_type_name = type(src_param) is type(map_param)
|
||||
src_param_attrs = {
|
||||
k: v
|
||||
for k, v in src_param.__dict__.items() if k not in key_to_ignore
|
||||
}
|
||||
map_param_attrs = {
|
||||
k: v
|
||||
for k, v in map_param.__dict__.items() if k not in key_to_ignore
|
||||
}
|
||||
has_same_attrs = src_param_attrs == map_param_attrs
|
||||
return has_same_type_name and has_same_attrs
|
||||
|
||||
def select_proj_params(
|
||||
self,
|
||||
layer: nn.Module,
|
||||
param: nn.Parameter,
|
||||
) -> nn.Parameter:
|
||||
"""
|
||||
Given the placeholder param,
|
||||
return the corresponding param in the proj layers.
|
||||
"""
|
||||
target_param_list = [
|
||||
v for _, v in layer.named_parameters()
|
||||
if self._is_same_param(param, v)
|
||||
]
|
||||
assert len(target_param_list) == 1
|
||||
target_param = target_param_list[0]
|
||||
return target_param
|
||||
|
||||
def forward( # type: ignore[override]
|
||||
self,
|
||||
decoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
q, _ = self.q_proj_decoder(decoder_hidden_states)
|
||||
if encoder_hidden_states is None:
|
||||
# Encoder KV already cached.
|
||||
@@ -1300,25 +1447,21 @@ class QKVCrossParallelLinear(torch.nn.Module):
|
||||
k, v = kv_enc.split(self.kv_size, dim=-1)
|
||||
return q, k, v
|
||||
|
||||
def weight_loader_weight(self,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
# NOTE Use QKV/ColumnParallel weight_loader, ignore placeholder param.
|
||||
param = self.q_proj_decoder.weight if loaded_shard_id == "q" \
|
||||
else self.kv_proj_encoder.weight
|
||||
param.weight_loader(
|
||||
param,
|
||||
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
|
||||
param, loaded_weight, loaded_shard_id)
|
||||
def weight_loader(self,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
layer = (self.q_proj_decoder
|
||||
if loaded_shard_id == "q" else self.kv_proj_encoder)
|
||||
target_param = self.select_proj_params(layer, param)
|
||||
shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
|
||||
layer.weight_loader(target_param, loaded_weight, *shard_id_args)
|
||||
|
||||
def weight_loader_bias(self,
|
||||
param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
param = self.q_proj_decoder.bias if loaded_shard_id == "q" \
|
||||
else self.kv_proj_encoder.bias
|
||||
param.weight_loader(
|
||||
param,
|
||||
loaded_weight) if loaded_shard_id == "q" else param.weight_loader(
|
||||
param, loaded_weight, loaded_shard_id)
|
||||
def extra_repr(self) -> str:
|
||||
s = f"in_features={self.input_size}"
|
||||
s += f", q_size={self.q_proj_decoder.output_size_per_partition}"
|
||||
s += f", kv_size={self.kv_size}"
|
||||
s += f", bias={self.bias is not None}"
|
||||
s += f", tp_size={get_tensor_model_parallel_world_size()}"
|
||||
s += ", gather_output=False"
|
||||
return s
|
||||
|
||||
Reference in New Issue
Block a user