[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS (#6036)
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
Co-authored-by: xinyuxiao <xinyuxiao2024@gmail.com>
This commit is contained in:
@@ -31,6 +31,8 @@ logger = init_logger(__name__)
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"CompressedTensorsLinearMethod",
|
||||
"BitBLASLinearMethod",
|
||||
"GPTQBitBLASLinearMethod",
|
||||
"AWQMarlinLinearMethod",
|
||||
"AWQLinearMethod",
|
||||
"GPTQMarlinLinearMethod",
|
||||
@@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
]
|
||||
|
||||
|
||||
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a weight shard's size and offset for BitBLAS tiled layouts.

    If *param* carries a ``bitblas_tile_size`` attribute, the shard size
    and offset are converted from element units into whole-tile units by
    integer division; otherwise they are returned unchanged.

    Args:
        param: Weight parameter; may expose ``bitblas_tile_size``.
        shard_size: Shard extent along the partitioned dimension.
        shard_offset: Shard start position along the partitioned dimension.

    Returns:
        Tuple ``(shard_size, shard_offset)``, tile-scaled when applicable.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    # No BitBLAS tiling on this parameter -> pass values through untouched.
    if tile is None:
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile
|
||||
|
||||
|
||||
def adjust_marlin_shard(param, shard_size, shard_offset):
|
||||
marlin_tile_size = getattr(param, "marlin_tile_size", None)
|
||||
if marlin_tile_size is None:
|
||||
@@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
shard_size, shard_offset = adjust_bitblas_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
index = list(itertools.accumulate([0] + self.output_sizes))
|
||||
orig_offsets = {
|
||||
@@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
shard_size, shard_offset = adjust_bitblas_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
|
||||
False)
|
||||
|
||||
Reference in New Issue
Block a user