[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS (#6036)

Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
Co-authored-by: xinyuxiao <xinyuxiao2024@gmail.com>
This commit is contained in:
Lei Wang
2025-04-22 16:01:36 +08:00
committed by GitHub
parent c4ab9f3e71
commit 8d32dc603d
15 changed files with 1864 additions and 7 deletions

View File

@@ -282,10 +282,12 @@ class PackedColumnParameter(_ColumnvLLMParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
bitblas_tile_size: Optional[int] = None,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile_size = marlin_tile_size
self._bitblas_tile_size = bitblas_tile_size
super().__init__(**kwargs)
@property
@@ -300,12 +302,17 @@ class PackedColumnParameter(_ColumnvLLMParameter):
def marlin_tile_size(self):
    """Return the Marlin tile size stored on this parameter.

    The value is whatever was assigned to ``_marlin_tile_size``; it is
    returned unchanged.
    """
    return self._marlin_tile_size
@property
def bitblas_tile_size(self):
    """Return the BitBLAS tile size stored on this parameter.

    The value is whatever was assigned to ``_bitblas_tile_size``; it is
    returned unchanged.
    """
    return self._bitblas_tile_size
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
    """Translate shard indexes for this packed parameter.

    Forwards ``shard_size`` and ``shard_offset`` to the module-level
    ``_adjust_shard_indexes_for_packing`` helper together with this
    parameter's ``packed_factor``, ``marlin_tile_size`` and
    ``bitblas_tile_size``, and returns the helper's result unchanged.
    """
    # Both tile sizes are always forwarded; choosing which one applies
    # is left to the helper.
    return _adjust_shard_indexes_for_packing(
        shard_size=shard_size,
        shard_offset=shard_offset,
        packed_factor=self.packed_factor,
        marlin_tile_size=self.marlin_tile_size,
        bitblas_tile_size=self.bitblas_tile_size)
class PackedvLLMParameter(ModelWeightParameter):
@@ -323,10 +330,12 @@ class PackedvLLMParameter(ModelWeightParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
bitblas_tile_size: Optional[int] = None,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile_size = marlin_tile_size
self._bitblas_tile_size = bitblas_tile_size
super().__init__(**kwargs)
@property
@@ -341,12 +350,17 @@ class PackedvLLMParameter(ModelWeightParameter):
def marlin_tile_size(self):
    """Return the Marlin tile size held in ``_marlin_tile_size``.

    No conversion is applied; the stored value is handed back as is.
    """
    return self._marlin_tile_size
@property
def bitblas_tile_size(self):
    """Return the BitBLAS tile size held in ``_bitblas_tile_size``.

    No conversion is applied; the stored value is handed back as is.
    """
    return self._bitblas_tile_size
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
    """Translate shard indexes for this packed parameter.

    Passes ``shard_size`` and ``shard_offset`` on to the module-level
    ``_adjust_shard_indexes_for_packing`` helper along with this
    parameter's ``packed_factor``, ``marlin_tile_size`` and
    ``bitblas_tile_size``; the helper's return value is returned as is.
    """
    # The helper decides which (if any) tile-size adjustment to apply,
    # so both values are passed through unconditionally.
    return _adjust_shard_indexes_for_packing(
        shard_size=shard_size,
        shard_offset=shard_offset,
        packed_factor=self.packed_factor,
        marlin_tile_size=self.marlin_tile_size,
        bitblas_tile_size=self.bitblas_tile_size)
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
@@ -421,8 +435,13 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset,
                                      bitblas_tile_size):
    """Scale shard indexes down by the BitBLAS tile size.

    Args:
        shard_size: Size of the shard before the tile adjustment.
        shard_offset: Offset of the shard before the tile adjustment.
        bitblas_tile_size: Divisor applied to both values.

    Returns:
        A ``(shard_size, shard_offset)`` tuple with each value floor-divided
        by ``bitblas_tile_size``.
    """
    # Floor division: a value that is not a multiple of the tile size is
    # truncated rather than rejected.
    return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
marlin_tile_size):
marlin_tile_size, bitblas_tile_size):
shard_size = shard_size // packed_factor
shard_offset = shard_offset // packed_factor
if marlin_tile_size is not None:
@@ -430,4 +449,10 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
shard_size=shard_size,
shard_offset=shard_offset,
marlin_tile_size=marlin_tile_size)
return shard_size, shard_offset
elif bitblas_tile_size is not None:
return _adjust_shard_indexes_for_bitblas(
shard_size=shard_size,
shard_offset=shard_offset,
bitblas_tile_size=bitblas_tile_size)
return shard_size, shard_offset