[Misc] Add shard_id validation for MergedColumnLinear (#35055)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -66,15 +66,23 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
]
|
||||
|
||||
|
||||
def adjust_marlin_shard(param, shard_size, shard_offset):
|
||||
marlin_tile_size = getattr(param, "marlin_tile_size", None)
|
||||
def adjust_marlin_shard(
|
||||
param: Parameter,
|
||||
shard_size: int,
|
||||
shard_offset: int,
|
||||
) -> tuple[int, int]:
|
||||
marlin_tile_size: int | None = getattr(param, "marlin_tile_size", None)
|
||||
if marlin_tile_size is None:
|
||||
return shard_size, shard_offset
|
||||
|
||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||
|
||||
|
||||
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
|
||||
def adjust_block_scale_shard(
|
||||
weight_block_size: tuple[int, ...] | None,
|
||||
shard_size: int,
|
||||
shard_offset: int,
|
||||
) -> tuple[int, int]:
|
||||
assert weight_block_size is not None
|
||||
block_n = weight_block_size[0]
|
||||
shard_offset = (shard_offset + block_n - 1) // block_n
|
||||
@@ -83,7 +91,9 @@ def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
|
||||
|
||||
|
||||
def adjust_bitsandbytes_4bit_shard(
|
||||
param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
|
||||
param: Parameter,
|
||||
shard_offsets: dict[str, tuple[int, int]],
|
||||
loaded_shard_id: str,
|
||||
) -> tuple[int, int]:
|
||||
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
||||
|
||||
@@ -97,7 +107,11 @@ def adjust_bitsandbytes_4bit_shard(
|
||||
return quantized_size, quantized_offset
|
||||
|
||||
|
||||
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
||||
def adjust_scalar_to_fused_array(
|
||||
param_data: torch.Tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
shard_id: int | str,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""For fused modules (QKV and MLP) we have an array of length
|
||||
N that holds 1 scale for each "logical" matrix. So the param
|
||||
is an array of length N. The loaded_weight corresponds to
|
||||
@@ -117,12 +131,14 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
||||
assert loaded_weight.shape[0] == 1
|
||||
loaded_weight = loaded_weight[0]
|
||||
|
||||
return param[shard_id], loaded_weight
|
||||
return param_data[shard_id], loaded_weight
|
||||
|
||||
|
||||
# TODO(Isotr0py): We might need a more flexible structure to handle
|
||||
# bitsandbytes shard offsets.
|
||||
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
|
||||
def left_shift_bitsandbytes_4bit_shard(
|
||||
bnb_weight_attrs: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""
|
||||
Separate the BitsAndBytes 4-bit shard.
|
||||
|
||||
@@ -681,12 +697,41 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
|
||||
if loaded_shard_id is None:
|
||||
return
|
||||
if isinstance(loaded_shard_id, tuple):
|
||||
for idx in loaded_shard_id:
|
||||
if not (0 <= idx < len(self.output_sizes)):
|
||||
raise ValueError(
|
||||
f"Shard id index {idx} should be between 0 and "
|
||||
f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
|
||||
)
|
||||
if len(loaded_shard_id) > 1 and any(
|
||||
b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
|
||||
):
|
||||
raise ValueError(
|
||||
"Shard id with multiple indices should be consecutive. "
|
||||
f"Got shard id {loaded_shard_id}."
|
||||
)
|
||||
return
|
||||
elif isinstance(loaded_shard_id, int):
|
||||
if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
|
||||
raise ValueError(
|
||||
f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
|
||||
f"Got shard id {loaded_shard_id}."
|
||||
)
|
||||
return
|
||||
raise ValueError("This line should not be reached")
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
||||
):
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
# FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
|
||||
if isinstance(loaded_shard_id, tuple):
|
||||
raise NotImplementedError(
|
||||
"Shard id with multiple indices is not supported in weight_loader, "
|
||||
@@ -874,6 +919,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
||||
):
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
|
||||
@@ -1005,6 +1051,18 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
disable_tp=disable_tp,
|
||||
)
|
||||
|
||||
def validate_shard_id(self, loaded_shard_id: str | None):
|
||||
if loaded_shard_id is None:
|
||||
return
|
||||
if isinstance(loaded_shard_id, str):
|
||||
if loaded_shard_id not in ["q", "k", "v"]:
|
||||
raise ValueError(
|
||||
"Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
|
||||
f"got shard id {loaded_shard_id}."
|
||||
)
|
||||
return
|
||||
raise ValueError("This line should not be reached")
|
||||
|
||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||
shard_offset_mapping = {
|
||||
"q": 0,
|
||||
@@ -1073,6 +1131,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: str | None = None,
|
||||
):
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
if loaded_shard_id is None: # special case for certain models
|
||||
if isinstance(param, PerTensorScaleParameter):
|
||||
param.load_qkv_weight(
|
||||
@@ -1112,6 +1171,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: str | None = None,
|
||||
):
|
||||
self.validate_shard_id(loaded_shard_id)
|
||||
# Special case for GGUF
|
||||
# initialize GGUF param after we know the quantize type
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
|
||||
Reference in New Issue
Block a user