[Misc] Add shard_id validation for MergedColumnLinear (#35055)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -66,15 +66,23 @@ WEIGHT_LOADER_V2_SUPPORTED = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def adjust_marlin_shard(
    param: Parameter,
    shard_size: int,
    shard_offset: int,
) -> tuple[int, int]:
    """Scale a shard's size and offset by the parameter's Marlin tile size.

    Args:
        param: Parameter that may carry a ``marlin_tile_size`` attribute.
        shard_size: Size of the shard before adjustment.
        shard_offset: Offset of the shard before adjustment.

    Returns:
        ``(shard_size, shard_offset)``, each multiplied by
        ``param.marlin_tile_size`` when that attribute exists and is not
        ``None``; otherwise the inputs unchanged.
    """
    tile: int | None = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    # No tile size recorded on the parameter: nothing to rescale.
    return shard_size, shard_offset
|
||||||
|
|
||||||
|
|
||||||
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
|
def adjust_block_scale_shard(
|
||||||
|
weight_block_size: tuple[int, ...] | None,
|
||||||
|
shard_size: int,
|
||||||
|
shard_offset: int,
|
||||||
|
) -> tuple[int, int]:
|
||||||
assert weight_block_size is not None
|
assert weight_block_size is not None
|
||||||
block_n = weight_block_size[0]
|
block_n = weight_block_size[0]
|
||||||
shard_offset = (shard_offset + block_n - 1) // block_n
|
shard_offset = (shard_offset + block_n - 1) // block_n
|
||||||
@@ -83,7 +91,9 @@ def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
|
|||||||
|
|
||||||
|
|
||||||
def adjust_bitsandbytes_4bit_shard(
|
def adjust_bitsandbytes_4bit_shard(
|
||||||
param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str
|
param: Parameter,
|
||||||
|
shard_offsets: dict[str, tuple[int, int]],
|
||||||
|
loaded_shard_id: str,
|
||||||
) -> tuple[int, int]:
|
) -> tuple[int, int]:
|
||||||
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
||||||
|
|
||||||
@@ -97,7 +107,11 @@ def adjust_bitsandbytes_4bit_shard(
|
|||||||
return quantized_size, quantized_offset
|
return quantized_size, quantized_offset
|
||||||
|
|
||||||
|
|
||||||
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
def adjust_scalar_to_fused_array(
|
||||||
|
param_data: torch.Tensor,
|
||||||
|
loaded_weight: torch.Tensor,
|
||||||
|
shard_id: int | str,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""For fused modules (QKV and MLP) we have an array of length
|
"""For fused modules (QKV and MLP) we have an array of length
|
||||||
N that holds 1 scale for each "logical" matrix. So the param
|
N that holds 1 scale for each "logical" matrix. So the param
|
||||||
is an array of length N. The loaded_weight corresponds to
|
is an array of length N. The loaded_weight corresponds to
|
||||||
@@ -117,12 +131,14 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
|||||||
assert loaded_weight.shape[0] == 1
|
assert loaded_weight.shape[0] == 1
|
||||||
loaded_weight = loaded_weight[0]
|
loaded_weight = loaded_weight[0]
|
||||||
|
|
||||||
return param[shard_id], loaded_weight
|
return param_data[shard_id], loaded_weight
|
||||||
|
|
||||||
|
|
||||||
# TODO(Isotr0py): We might need a more flexible structure to handle
|
# TODO(Isotr0py): We might need a more flexible structure to handle
|
||||||
# bitsandbytes shard offsets.
|
# bitsandbytes shard offsets.
|
||||||
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]):
|
def left_shift_bitsandbytes_4bit_shard(
|
||||||
|
bnb_weight_attrs: dict[str, Any],
|
||||||
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Separate the BitsAndBytes 4-bit shard.
|
Separate the BitsAndBytes 4-bit shard.
|
||||||
|
|
||||||
@@ -681,12 +697,41 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
disable_tp=disable_tp,
|
disable_tp=disable_tp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
|
||||||
|
if loaded_shard_id is None:
|
||||||
|
return
|
||||||
|
if isinstance(loaded_shard_id, tuple):
|
||||||
|
for idx in loaded_shard_id:
|
||||||
|
if not (0 <= idx < len(self.output_sizes)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Shard id index {idx} should be between 0 and "
|
||||||
|
f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
|
||||||
|
)
|
||||||
|
if len(loaded_shard_id) > 1 and any(
|
||||||
|
b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Shard id with multiple indices should be consecutive. "
|
||||||
|
f"Got shard id {loaded_shard_id}."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
elif isinstance(loaded_shard_id, int):
|
||||||
|
if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
|
||||||
|
raise ValueError(
|
||||||
|
f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
|
||||||
|
f"Got shard id {loaded_shard_id}."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
raise ValueError("This line should not be reached")
|
||||||
|
|
||||||
def weight_loader(
|
def weight_loader(
|
||||||
self,
|
self,
|
||||||
param: Parameter,
|
param: Parameter,
|
||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
loaded_shard_id: tuple[int, ...] | int | None = None,
|
||||||
):
|
):
|
||||||
|
self.validate_shard_id(loaded_shard_id)
|
||||||
|
# FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
|
||||||
if isinstance(loaded_shard_id, tuple):
|
if isinstance(loaded_shard_id, tuple):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Shard id with multiple indices is not supported in weight_loader, "
|
"Shard id with multiple indices is not supported in weight_loader, "
|
||||||
@@ -874,6 +919,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
loaded_shard_id: tuple[int, ...] | int | None = None,
|
loaded_shard_id: tuple[int, ...] | int | None = None,
|
||||||
):
|
):
|
||||||
|
self.validate_shard_id(loaded_shard_id)
|
||||||
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
|
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
|
||||||
if isinstance(param, PerTensorScaleParameter):
|
if isinstance(param, PerTensorScaleParameter):
|
||||||
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
|
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
|
||||||
@@ -1005,6 +1051,18 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
disable_tp=disable_tp,
|
disable_tp=disable_tp,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def validate_shard_id(self, loaded_shard_id: str | None):
|
||||||
|
if loaded_shard_id is None:
|
||||||
|
return
|
||||||
|
if isinstance(loaded_shard_id, str):
|
||||||
|
if loaded_shard_id not in ["q", "k", "v"]:
|
||||||
|
raise ValueError(
|
||||||
|
"Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
|
||||||
|
f"got shard id {loaded_shard_id}."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
raise ValueError("This line should not be reached")
|
||||||
|
|
||||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||||
shard_offset_mapping = {
|
shard_offset_mapping = {
|
||||||
"q": 0,
|
"q": 0,
|
||||||
@@ -1073,6 +1131,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
loaded_shard_id: str | None = None,
|
loaded_shard_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
self.validate_shard_id(loaded_shard_id)
|
||||||
if loaded_shard_id is None: # special case for certain models
|
if loaded_shard_id is None: # special case for certain models
|
||||||
if isinstance(param, PerTensorScaleParameter):
|
if isinstance(param, PerTensorScaleParameter):
|
||||||
param.load_qkv_weight(
|
param.load_qkv_weight(
|
||||||
@@ -1112,6 +1171,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
loaded_weight: torch.Tensor,
|
loaded_weight: torch.Tensor,
|
||||||
loaded_shard_id: str | None = None,
|
loaded_shard_id: str | None = None,
|
||||||
):
|
):
|
||||||
|
self.validate_shard_id(loaded_shard_id)
|
||||||
# Special case for GGUF
|
# Special case for GGUF
|
||||||
# initialize GGUF param after we know the quantize type
|
# initialize GGUF param after we know the quantize type
|
||||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||||
|
|||||||
Reference in New Issue
Block a user