[Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776)

This commit is contained in:
chenqianfzh
2024-06-01 13:51:10 -07:00
committed by GitHub
parent 37464a0f74
commit b9c0605a8e
11 changed files with 752 additions and 8 deletions

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
@@ -26,6 +26,21 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def adjust_bitsandbytes_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total, _ = qkv_offsets["total"]
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
quantized_total = param.data.shape[0]
quantized_offset = orig_offset * quantized_total // total
quantized_size = orig_size * quantized_total // total
return quantized_size, quantized_offset
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@@ -37,7 +52,7 @@ class LinearMethodBase(QuantizeMethodBase):
**extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
@@ -416,6 +431,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
@@ -615,6 +636,22 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size,
self.num_kv_heads * self.head_size),
"v":
((self.num_heads + self.num_kv_heads) * self.head_size,
self.num_kv_heads * self.head_size),
"total":
((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
0)
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":