[Bugfix] Fix Phi-3 BNB quantization with tensor parallel (#9948)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2024-11-22 15:01:56 +08:00
committed by GitHub
parent a111d0151f
commit b6374e09b0
2 changed files with 56 additions and 6 deletions

View File

@@ -1,3 +1,4 @@
import itertools
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple
@@ -41,12 +42,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
def adjust_bitsandbytes_4bit_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
shard_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total, _ = qkv_offsets["total"]
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
total, _ = shard_offsets["total"]
orig_offset, orig_size = shard_offsets[loaded_shard_id]
quantized_total = param.data.shape[0]
quantized_offset = orig_offset * quantized_total // total
@@ -499,9 +500,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim] // 2
shard_offset = shard_size * shard_id
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
str(i): (index[i], size)
for i, size in enumerate(self.output_sizes)
}
orig_offsets["total"] = (self.output_size, 0)
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_offsets, str(shard_id))
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)