support bitsandbytes 8-bit and FP4 quantized models (#7445)

This commit is contained in:
chenqianfzh
2024-08-29 16:09:08 -07:00
committed by GitHub
parent 257afc37c5
commit 4664ceaad6
6 changed files with 435 additions and 189 deletions

View File

@@ -36,9 +36,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def adjust_bitsandbytes_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
def adjust_bitsandbytes_4bit_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total, _ = qkv_offsets["total"]
@@ -505,8 +505,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
@@ -858,8 +859,9 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size,
@@ -871,7 +873,7 @@ class QKVParallelLinear(ColumnParallelLinear):
((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
0)
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight: