Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,14 +9,21 @@ from weakref import WeakValueDictionary
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
__all__ = [
|
||||
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
|
||||
"ModelWeightParameter", "ChannelQuantScaleParameter",
|
||||
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
|
||||
"BasevLLMParameter",
|
||||
"PackedvLLMParameter",
|
||||
"PerTensorScaleParameter",
|
||||
"ModelWeightParameter",
|
||||
"ChannelQuantScaleParameter",
|
||||
"GroupQuantScaleParameter",
|
||||
"PackedColumnParameter",
|
||||
"RowvLLMParameter",
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -30,7 +37,6 @@ class BasevLLMParameter(Parameter):
|
||||
"""
|
||||
|
||||
def __new__(cls, data: Optional[torch.Tensor], **kwargs):
|
||||
|
||||
return super().__new__(cls, data=data, requires_grad=False)
|
||||
|
||||
def __init__(self, data: torch.Tensor, weight_loader: Callable):
|
||||
@@ -52,9 +58,9 @@ class BasevLLMParameter(Parameter):
|
||||
# This sometimes causes OOM errors during model loading. To avoid this,
|
||||
# we sync the param tensor after its weight loader is called.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.use_sync_weight_loader():
|
||||
weight_loader = current_platform.make_synced_weight_loader(
|
||||
weight_loader)
|
||||
weight_loader = current_platform.make_synced_weight_loader(weight_loader)
|
||||
|
||||
self._weight_loader = weight_loader
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
@@ -67,8 +73,9 @@ class BasevLLMParameter(Parameter):
|
||||
# weight loading should be implemented via Model.load_weights. In the
|
||||
# meantime, support deleting and overriding `weight_loader`` attribute
|
||||
if self._weight_loader is None:
|
||||
raise AttributeError(f"{self.__class__.__name__} weight_loader "
|
||||
"attribute has been deleted")
|
||||
raise AttributeError(
|
||||
f"{self.__class__.__name__} weight_loader attribute has been deleted"
|
||||
)
|
||||
return self._weight_loader
|
||||
|
||||
@weight_loader.setter
|
||||
@@ -82,11 +89,12 @@ class BasevLLMParameter(Parameter):
|
||||
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
|
||||
cond1 = self.data.ndim == 1 and self.data.numel() == 1
|
||||
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
|
||||
return (cond1 and cond2)
|
||||
return cond1 and cond2
|
||||
|
||||
def _assert_and_load(self, loaded_weight: torch.Tensor):
|
||||
assert (self.data.shape == loaded_weight.shape
|
||||
or self._is_1d_and_scalar(loaded_weight))
|
||||
assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar(
|
||||
loaded_weight
|
||||
)
|
||||
self.data.copy_(loaded_weight)
|
||||
|
||||
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
@@ -121,11 +129,11 @@ class BasevLLMParameter(Parameter):
|
||||
|
||||
class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
"""
|
||||
Private class defining weight loading functionality
|
||||
Private class defining weight loading functionality
|
||||
(load_merged_column_weight, load_qkv_weight)
|
||||
for parameters being loaded into linear layers with column
|
||||
parallelism. This includes QKV and MLP layers which are
|
||||
not already fused on disk. Requires an output dimension
|
||||
not already fused on disk. Requires an output dimension
|
||||
to be defined. Called within the weight loader of
|
||||
each of the column parallel linear layers.
|
||||
"""
|
||||
@@ -140,57 +148,55 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
|
||||
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
shard_size = self.data.shape[self.output_dim]
|
||||
loaded_weight = loaded_weight.narrow(self.output_dim,
|
||||
self.tp_rank * shard_size,
|
||||
shard_size)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, self.tp_rank * shard_size, shard_size
|
||||
)
|
||||
assert self.data.shape == loaded_weight.shape
|
||||
self.data.copy_(loaded_weight)
|
||||
|
||||
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
|
||||
shard_offset = kwargs.get("shard_offset")
|
||||
shard_size = kwargs.get("shard_size")
|
||||
|
||||
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
|
||||
if isinstance(
|
||||
self,
|
||||
(PackedColumnParameter,
|
||||
PackedvLLMParameter)) and self.packed_dim == self.output_dim:
|
||||
if (
|
||||
isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
|
||||
and self.packed_dim == self.output_dim
|
||||
):
|
||||
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
|
||||
shard_offset=shard_offset, shard_size=shard_size)
|
||||
shard_offset=shard_offset, shard_size=shard_size
|
||||
)
|
||||
|
||||
param_data = self.data
|
||||
|
||||
param_data = param_data.narrow(self.output_dim, shard_offset,
|
||||
shard_size)
|
||||
loaded_weight = loaded_weight.narrow(self.output_dim,
|
||||
self.tp_rank * shard_size,
|
||||
shard_size)
|
||||
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, self.tp_rank * shard_size, shard_size
|
||||
)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
|
||||
shard_offset = kwargs.get("shard_offset")
|
||||
shard_size = kwargs.get("shard_size")
|
||||
shard_id = kwargs.get("shard_id")
|
||||
num_heads = kwargs.get("num_heads")
|
||||
|
||||
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
|
||||
if isinstance(
|
||||
self,
|
||||
(PackedColumnParameter,
|
||||
PackedvLLMParameter)) and self.output_dim == self.packed_dim:
|
||||
if (
|
||||
isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
|
||||
and self.output_dim == self.packed_dim
|
||||
):
|
||||
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
|
||||
shard_offset=shard_offset, shard_size=shard_size)
|
||||
shard_offset=shard_offset, shard_size=shard_size
|
||||
)
|
||||
|
||||
param_data = self.data
|
||||
shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank //
|
||||
num_heads)
|
||||
param_data = param_data.narrow(self.output_dim, shard_offset,
|
||||
shard_size)
|
||||
loaded_weight = loaded_weight.narrow(self.output_dim,
|
||||
shard_id * shard_size, shard_size)
|
||||
shard_id = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads
|
||||
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, shard_id * shard_size, shard_size
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
@@ -214,9 +220,9 @@ class RowvLLMParameter(BasevLLMParameter):
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
shard_size = self.data.shape[self.input_dim]
|
||||
loaded_weight = loaded_weight.narrow(self.input_dim,
|
||||
self.tp_rank * shard_size,
|
||||
shard_size)
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.input_dim, self.tp_rank * shard_size, shard_size
|
||||
)
|
||||
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
@@ -230,6 +236,7 @@ class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
Parameter class for linear layer weights. Uses both column and
|
||||
row parallelism.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -238,6 +245,7 @@ class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
Parameter class for weight scales loaded for weights with
|
||||
grouped quantization. Uses both column and row parallelism.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -246,6 +254,7 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
|
||||
Parameter class for weight scales loaded for weights with
|
||||
channel-wise quantization. Equivalent to _ColumnvLLMParameter.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -256,11 +265,11 @@ class PerTensorScaleParameter(BasevLLMParameter):
|
||||
layers (e.g. for QKV, there are 3 scales loaded from disk).
|
||||
This is relevant to weights with per-tensor quantization.
|
||||
Adds functionality to map the scalers to a shard during
|
||||
weight loading.
|
||||
weight loading.
|
||||
|
||||
Note: additional parameter manipulation may be handled
|
||||
for each quantization config specifically, within
|
||||
process_weights_after_loading
|
||||
Note: additional parameter manipulation may be handled
|
||||
for each quantization config specifically, within
|
||||
process_weights_after_loading
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
@@ -280,10 +289,11 @@ class PerTensorScaleParameter(BasevLLMParameter):
|
||||
def load_column_parallel_weight(self, *args, **kwargs):
|
||||
super().load_row_parallel_weight(*args, **kwargs)
|
||||
|
||||
def _load_into_shard_id(self, loaded_weight: torch.Tensor,
|
||||
shard_id: Union[str, int], **kwargs):
|
||||
def _load_into_shard_id(
|
||||
self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs
|
||||
):
|
||||
"""
|
||||
Slice the parameter data based on the shard id for
|
||||
Slice the parameter data based on the shard id for
|
||||
loading.
|
||||
"""
|
||||
|
||||
@@ -308,12 +318,14 @@ class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
for more details on the packed properties.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
self._marlin_tile_size = marlin_tile_size
|
||||
@@ -342,7 +354,8 @@ class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
shard_offset=shard_offset,
|
||||
packed_factor=self.packed_factor,
|
||||
marlin_tile_size=self.marlin_tile_size,
|
||||
bitblas_tile_size=self.bitblas_tile_size)
|
||||
bitblas_tile_size=self.bitblas_tile_size,
|
||||
)
|
||||
|
||||
|
||||
class PackedvLLMParameter(ModelWeightParameter):
|
||||
@@ -351,17 +364,19 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
|
||||
Extends the ModelWeightParameter to take in the
|
||||
packed factor, the packed dimension, and optionally, marlin
|
||||
tile size for marlin kernels. Adjusts the shard_size and
|
||||
tile size for marlin kernels. Adjusts the shard_size and
|
||||
shard_offset for fused linear layers model weight loading
|
||||
by accounting for packing and optionally, marlin tile size.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
packed_factor: Union[int, Fraction],
|
||||
packed_dim: int,
|
||||
marlin_tile_size: Optional[int] = None,
|
||||
bitblas_tile_size: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
self._marlin_tile_size = marlin_tile_size
|
||||
@@ -390,7 +405,8 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
shard_offset=shard_offset,
|
||||
packed_factor=self.packed_factor,
|
||||
marlin_tile_size=self.marlin_tile_size,
|
||||
bitblas_tile_size=self.bitblas_tile_size)
|
||||
bitblas_tile_size=self.bitblas_tile_size,
|
||||
)
|
||||
|
||||
|
||||
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
|
||||
@@ -410,6 +426,7 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
`MergedColumnParallelLinear`, the transform weights must stay separate
|
||||
tensors in order to allow for tensor memory sharing between layers.
|
||||
"""
|
||||
|
||||
# global registry for sharing tensors based on passed `data_key`
|
||||
# this dict holds weaksrefs to avoid memory leak after model cleanup
|
||||
tensors_registry: WeakValueDictionary = WeakValueDictionary()
|
||||
@@ -426,8 +443,7 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
return super().__new__(cls, data=None, **kwargs)
|
||||
|
||||
def __init__(self, input_dim: int = 1, output_dim: int = 0, **kwargs):
|
||||
weight_loader: Callable = kwargs.get(
|
||||
"weight_loader") # type: ignore[assignment]
|
||||
weight_loader: Callable = kwargs.get("weight_loader") # type: ignore[assignment]
|
||||
super().__init__(data=None, weight_loader=weight_loader)
|
||||
|
||||
self.local_tensors = set()
|
||||
@@ -435,12 +451,14 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
self.kwargs = {
|
||||
"input_dim": input_dim,
|
||||
"output_dim": output_dim,
|
||||
"weight_loader": self._fake_weight_loader
|
||||
"weight_loader": self._fake_weight_loader,
|
||||
}
|
||||
|
||||
if self.tp_size > 1:
|
||||
raise NotImplementedError(f"{self.__class__.__name__} does not "
|
||||
"currently support tensor parallelism")
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not "
|
||||
"currently support tensor parallelism"
|
||||
)
|
||||
|
||||
def add_partition(self, index: int, data_key: Hashable, *args, **kwargs):
|
||||
"""
|
||||
@@ -460,8 +478,7 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
data = self.tensors_registry[data_key]
|
||||
|
||||
# create associated model parameter
|
||||
self.partitions[index] = ModelWeightParameter(
|
||||
data=data, **self.kwargs) # type: ignore[arg-type]
|
||||
self.partitions[index] = ModelWeightParameter(data=data, **self.kwargs) # type: ignore[arg-type]
|
||||
|
||||
# hold local reference, since ModelWeightParameter does not
|
||||
# see https://github.com/pytorch/pytorch/issues/75932
|
||||
@@ -471,8 +488,7 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
assert len(self.partitions) == 1 and 0 in self.partitions
|
||||
partition = self.partitions[0]
|
||||
|
||||
ModelWeightParameter.load_column_parallel_weight(
|
||||
partition, loaded_weight)
|
||||
ModelWeightParameter.load_column_parallel_weight(partition, loaded_weight)
|
||||
|
||||
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
assert len(self.partitions) == 1 and 0 in self.partitions
|
||||
@@ -490,10 +506,8 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
shard_offset = self.tp_rank * shard_size
|
||||
|
||||
ModelWeightParameter.load_merged_column_weight(
|
||||
partition,
|
||||
loaded_weight,
|
||||
shard_offset=shard_offset,
|
||||
shard_size=shard_size)
|
||||
partition, loaded_weight, shard_offset=shard_offset, shard_size=shard_size
|
||||
)
|
||||
|
||||
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
|
||||
partition_id = self._shard_id_as_int(kwargs.pop("shard_id"))
|
||||
@@ -517,33 +531,42 @@ class SharedWeightParameter(BasevLLMParameter):
|
||||
def process_weights_after_loading(self):
|
||||
for key in self.partitions:
|
||||
self.partitions[key] = torch.nn.Parameter(
|
||||
data=self.partitions[key].data, requires_grad=False)
|
||||
data=self.partitions[key].data, requires_grad=False
|
||||
)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
raise ValueError("Accessing `data` of a "
|
||||
"`PartitionedModelWeightParameter` is not allowed. "
|
||||
"Instead, use `get_partition` to get the weight of "
|
||||
"the particular partition you want to access")
|
||||
raise ValueError(
|
||||
"Accessing `data` of a "
|
||||
"`PartitionedModelWeightParameter` is not allowed. "
|
||||
"Instead, use `get_partition` to get the weight of "
|
||||
"the particular partition you want to access"
|
||||
)
|
||||
|
||||
def _fake_weight_loader(self, param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_weight_shard_id: Optional[Union[str, int]]):
|
||||
raise ValueError("When loading partition weights of "
|
||||
f"{self.__class__.__name__}, use methods provided by "
|
||||
f"{self.__class__.__name__}, not partition loader")
|
||||
def _fake_weight_loader(
|
||||
self,
|
||||
param: BasevLLMParameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_weight_shard_id: Optional[Union[str, int]],
|
||||
):
|
||||
raise ValueError(
|
||||
"When loading partition weights of "
|
||||
f"{self.__class__.__name__}, use methods provided by "
|
||||
f"{self.__class__.__name__}, not partition loader"
|
||||
)
|
||||
|
||||
|
||||
def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
|
||||
output_dim: int, **kwargs) -> BasevLLMParameter:
|
||||
def permute_param_layout_(
|
||||
param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs
|
||||
) -> BasevLLMParameter:
|
||||
"""
|
||||
Permute a parameter's layout to the specified input and output dimensions,
|
||||
Permute a parameter's layout to the specified input and output dimensions,
|
||||
useful for forcing the parameter into a known layout, for example, if I need
|
||||
a packed (quantized) weight matrix to be in the layout
|
||||
a packed (quantized) weight matrix to be in the layout
|
||||
{input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
then I can call:
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
to ensure x is in the correct layout (permuting it to the correct layout if
|
||||
to ensure x is in the correct layout (permuting it to the correct layout if
|
||||
required, asserting if it cannot get it to the correct layout)
|
||||
"""
|
||||
|
||||
@@ -551,35 +574,34 @@ def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
|
||||
curr_output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
if curr_input_dim is None or curr_output_dim is None:
|
||||
assert param.data.dim() == 2,\
|
||||
"permute_param_layout_ only supports 2D parameters when either "\
|
||||
assert param.data.dim() == 2, (
|
||||
"permute_param_layout_ only supports 2D parameters when either "
|
||||
"input_dim or output_dim is not set"
|
||||
)
|
||||
|
||||
# if one of the dimensions is not set, set it to the opposite of the other
|
||||
# we can only do this since we asserted the parameter is 2D above
|
||||
if curr_input_dim is None:
|
||||
assert curr_output_dim is not None,\
|
||||
"either input or output dim must be set"
|
||||
assert curr_output_dim is not None, "either input or output dim must be set"
|
||||
curr_input_dim = (curr_output_dim + 1) % 2
|
||||
if curr_output_dim is None:
|
||||
assert curr_input_dim is not None,\
|
||||
"either input or output dim must be set"
|
||||
assert curr_input_dim is not None, "either input or output dim must be set"
|
||||
curr_output_dim = (curr_input_dim + 1) % 2
|
||||
|
||||
# create permutation from the current layout to the layout with
|
||||
# self.input_dim at input_dim and self.output_dim at output_dim preserving
|
||||
# other dimensions
|
||||
perm = [
|
||||
i for i in range(param.data.dim())
|
||||
if i not in [curr_input_dim, curr_output_dim]
|
||||
i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim]
|
||||
]
|
||||
perm.insert(input_dim, curr_input_dim)
|
||||
perm.insert(output_dim, curr_output_dim)
|
||||
|
||||
if "packed_dim" in kwargs:
|
||||
assert hasattr(param, "packed_dim") and\
|
||||
param.packed_dim == perm[kwargs["packed_dim"]],\
|
||||
"permute_param_layout_ currently doesn't support repacking"
|
||||
assert (
|
||||
hasattr(param, "packed_dim")
|
||||
and param.packed_dim == perm[kwargs["packed_dim"]]
|
||||
), "permute_param_layout_ currently doesn't support repacking"
|
||||
|
||||
param.data = param.data.permute(*perm)
|
||||
if hasattr(param, "_input_dim"):
|
||||
@@ -592,29 +614,30 @@ def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
|
||||
return param
|
||||
|
||||
|
||||
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
|
||||
marlin_tile_size):
|
||||
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size):
|
||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||
|
||||
|
||||
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset,
|
||||
bitblas_tile_size):
|
||||
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, bitblas_tile_size):
|
||||
return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size
|
||||
|
||||
|
||||
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
|
||||
marlin_tile_size, bitblas_tile_size):
|
||||
def _adjust_shard_indexes_for_packing(
|
||||
shard_size, shard_offset, packed_factor, marlin_tile_size, bitblas_tile_size
|
||||
):
|
||||
shard_size = shard_size // packed_factor
|
||||
shard_offset = shard_offset // packed_factor
|
||||
if marlin_tile_size is not None:
|
||||
return _adjust_shard_indexes_for_marlin(
|
||||
shard_size=shard_size,
|
||||
shard_offset=shard_offset,
|
||||
marlin_tile_size=marlin_tile_size)
|
||||
marlin_tile_size=marlin_tile_size,
|
||||
)
|
||||
elif bitblas_tile_size is not None:
|
||||
return _adjust_shard_indexes_for_bitblas(
|
||||
shard_size=shard_size,
|
||||
shard_offset=shard_offset,
|
||||
bitblas_tile_size=bitblas_tile_size)
|
||||
bitblas_tile_size=bitblas_tile_size,
|
||||
)
|
||||
|
||||
return shard_size, shard_offset
|
||||
|
||||
Reference in New Issue
Block a user