[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS (#6036)
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com> Co-authored-by: xinyuxiao <xinyuxiao2024@gmail.com>
This commit is contained in:
@@ -282,10 +282,12 @@ class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
**kwargs):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
self._marlin_tile_size = marlin_tile_size
|
||||
self._bitblas_tile_size = bitblas_tile_size
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
@@ -300,12 +302,17 @@ class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
def marlin_tile_size(self):
|
||||
return self._marlin_tile_size
|
||||
|
||||
@property
def bitblas_tile_size(self):
    """Return the BitBLAS tile size stored on this parameter.

    The value is whatever was passed as ``bitblas_tile_size`` to the
    constructor; it is ``None`` when no tile size was supplied.
    """
    return self._bitblas_tile_size
|
||||
|
||||
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
    """Adjust a shard's size and offset for this parameter's packing.

    Delegates to the module-level ``_adjust_shard_indexes_for_packing``,
    forwarding this parameter's ``packed_factor``, ``marlin_tile_size``
    and ``bitblas_tile_size``.

    Args:
        shard_size: Shard size before adjustment.
        shard_offset: Shard offset before adjustment.

    Returns:
        The adjusted ``(shard_size, shard_offset)`` pair produced by the
        helper.
    """
    # Both tile sizes are forwarded; the helper decides which one (if any)
    # applies based on which is not None.
    return _adjust_shard_indexes_for_packing(
        shard_size=shard_size,
        shard_offset=shard_offset,
        packed_factor=self.packed_factor,
        marlin_tile_size=self.marlin_tile_size,
        bitblas_tile_size=self.bitblas_tile_size)
|
||||
|
||||
|
||||
class PackedvLLMParameter(ModelWeightParameter):
|
||||
@@ -323,10 +330,12 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
**kwargs):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
self._marlin_tile_size = marlin_tile_size
|
||||
self._bitblas_tile_size = bitblas_tile_size
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
@@ -341,12 +350,17 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
def marlin_tile_size(self):
|
||||
return self._marlin_tile_size
|
||||
|
||||
@property
def bitblas_tile_size(self):
    """Return the BitBLAS tile size stored on this parameter.

    The value is whatever was passed as ``bitblas_tile_size`` to the
    constructor; it is ``None`` when no tile size was supplied.
    """
    return self._bitblas_tile_size
|
||||
|
||||
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
    """Adjust a shard's size and offset for this parameter's packing.

    Delegates to the module-level ``_adjust_shard_indexes_for_packing``,
    forwarding this parameter's ``packed_factor``, ``marlin_tile_size``
    and ``bitblas_tile_size``.

    Args:
        shard_size: Shard size before adjustment.
        shard_offset: Shard offset before adjustment.

    Returns:
        The adjusted ``(shard_size, shard_offset)`` pair produced by the
        helper.
    """
    # Both tile sizes are forwarded; the helper decides which one (if any)
    # applies based on which is not None.
    return _adjust_shard_indexes_for_packing(
        shard_size=shard_size,
        shard_offset=shard_offset,
        packed_factor=self.packed_factor,
        marlin_tile_size=self.marlin_tile_size,
        bitblas_tile_size=self.bitblas_tile_size)
|
||||
|
||||
|
||||
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
@@ -421,8 +435,13 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
|
||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||
|
||||
|
||||
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset,
|
||||
bitblas_tile_size):
|
||||
return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size
|
||||
|
||||
|
||||
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
|
||||
marlin_tile_size):
|
||||
marlin_tile_size, bitblas_tile_size):
|
||||
shard_size = shard_size // packed_factor
|
||||
shard_offset = shard_offset // packed_factor
|
||||
if marlin_tile_size is not None:
|
||||
@@ -430,4 +449,10 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
|
||||
shard_size=shard_size,
|
||||
shard_offset=shard_offset,
|
||||
marlin_tile_size=marlin_tile_size)
|
||||
return shard_size, shard_offset
|
||||
elif bitblas_tile_size is not None:
|
||||
return _adjust_shard_indexes_for_bitblas(
|
||||
shard_size=shard_size,
|
||||
shard_offset=shard_offset,
|
||||
bitblas_tile_size=bitblas_tile_size)
|
||||
|
||||
return shard_size, shard_offset
|
||||
Reference in New Issue
Block a user