[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS (#6036)

Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
Co-authored-by: xinyuxiao <xinyuxiao2024@gmail.com>
This commit is contained in:
Lei Wang
2025-04-22 16:01:36 +08:00
committed by GitHub
parent c4ab9f3e71
commit 8d32dc603d
15 changed files with 1864 additions and 7 deletions

View File

@@ -31,6 +31,8 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"BitBLASLinearMethod",
"GPTQBitBLASLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"GPTQMarlinLinearMethod",
@@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [
]
def adjust_bitblas_shard(param, shard_size, shard_offset):
    """Rescale a weight shard's size and offset for BitBLAS-tiled params.

    When *param* carries a ``bitblas_tile_size`` attribute, both the shard
    size and offset are divided by that tile size (integer division);
    otherwise they are returned unchanged.
    """
    tile = getattr(param, "bitblas_tile_size", None)
    if tile is None:
        # Not a BitBLAS-packed parameter — leave the shard untouched.
        return shard_size, shard_offset
    return shard_size // tile, shard_offset // tile
def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None:
@@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
shard_size, shard_offset = adjust_bitblas_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
@@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
shard_size, shard_offset = adjust_bitblas_shard(
param, shard_size, shard_offset)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)