[Misc] Add shard_id validation for MergedColumnLinear (#35055)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-02-25 14:12:28 +08:00
committed by GitHub
parent 4572a06afe
commit a6c137521c

View File

@@ -66,15 +66,23 @@ WEIGHT_LOADER_V2_SUPPORTED = [
] ]
def adjust_marlin_shard(param, shard_size, shard_offset): def adjust_marlin_shard(
marlin_tile_size = getattr(param, "marlin_tile_size", None) param: Parameter,
shard_size: int,
shard_offset: int,
) -> tuple[int, int]:
marlin_tile_size: int | None = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None: if marlin_tile_size is None:
return shard_size, shard_offset return shard_size, shard_offset
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset): def adjust_block_scale_shard(
weight_block_size: tuple[int, ...] | None,
shard_size: int,
shard_offset: int,
) -> tuple[int, int]:
assert weight_block_size is not None assert weight_block_size is not None
block_n = weight_block_size[0] block_n = weight_block_size[0]
shard_offset = (shard_offset + block_n - 1) // block_n shard_offset = (shard_offset + block_n - 1) // block_n
@@ -83,7 +91,9 @@ def adjust_block_scale_shard(weight_block_size, shard_size, shard_offset):
def adjust_bitsandbytes_4bit_shard( def adjust_bitsandbytes_4bit_shard(
param: Parameter, shard_offsets: dict[str, tuple[int, int]], loaded_shard_id: str param: Parameter,
shard_offsets: dict[str, tuple[int, int]],
loaded_shard_id: str,
) -> tuple[int, int]: ) -> tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
@@ -97,7 +107,11 @@ def adjust_bitsandbytes_4bit_shard(
return quantized_size, quantized_offset return quantized_size, quantized_offset
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): def adjust_scalar_to_fused_array(
param_data: torch.Tensor,
loaded_weight: torch.Tensor,
shard_id: int | str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""For fused modules (QKV and MLP) we have an array of length """For fused modules (QKV and MLP) we have an array of length
N that holds 1 scale for each "logical" matrix. So the param N that holds 1 scale for each "logical" matrix. So the param
is an array of length N. The loaded_weight corresponds to is an array of length N. The loaded_weight corresponds to
@@ -117,12 +131,14 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
assert loaded_weight.shape[0] == 1 assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
return param[shard_id], loaded_weight return param_data[shard_id], loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle # TODO(Isotr0py): We might need a more flexible structure to handle
# bitsandbytes shard offsets. # bitsandbytes shard offsets.
def left_shift_bitsandbytes_4bit_shard(bnb_weight_attrs: dict[str, Any]): def left_shift_bitsandbytes_4bit_shard(
bnb_weight_attrs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
""" """
Separate the BitsAndBytes 4-bit shard. Separate the BitsAndBytes 4-bit shard.
@@ -681,12 +697,41 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp=disable_tp, disable_tp=disable_tp,
) )
def validate_shard_id(self, loaded_shard_id: int | tuple[int, ...] | None):
if loaded_shard_id is None:
return
if isinstance(loaded_shard_id, tuple):
for idx in loaded_shard_id:
if not (0 <= idx < len(self.output_sizes)):
raise ValueError(
f"Shard id index {idx} should be between 0 and "
f"{len(self.output_sizes) - 1}. Got shard id {loaded_shard_id}."
)
if len(loaded_shard_id) > 1 and any(
b - a != 1 for a, b in zip(loaded_shard_id[:-1], loaded_shard_id[1:])
):
raise ValueError(
"Shard id with multiple indices should be consecutive. "
f"Got shard id {loaded_shard_id}."
)
return
elif isinstance(loaded_shard_id, int):
if loaded_shard_id < 0 or loaded_shard_id >= len(self.output_sizes):
raise ValueError(
f"Shard id should be between 0 and {len(self.output_sizes) - 1}. "
f"Got shard id {loaded_shard_id}."
)
return
raise ValueError("This line should not be reached")
def weight_loader( def weight_loader(
self, self,
param: Parameter, param: Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: tuple[int, ...] | int | None = None, loaded_shard_id: tuple[int, ...] | int | None = None,
): ):
self.validate_shard_id(loaded_shard_id)
# FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
if isinstance(loaded_shard_id, tuple): if isinstance(loaded_shard_id, tuple):
raise NotImplementedError( raise NotImplementedError(
"Shard id with multiple indices is not supported in weight_loader, " "Shard id with multiple indices is not supported in weight_loader, "
@@ -874,6 +919,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: tuple[int, ...] | int | None = None, loaded_shard_id: tuple[int, ...] | int | None = None,
): ):
self.validate_shard_id(loaded_shard_id)
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
@@ -1005,6 +1051,18 @@ class QKVParallelLinear(ColumnParallelLinear):
disable_tp=disable_tp, disable_tp=disable_tp,
) )
def validate_shard_id(self, loaded_shard_id: str | None):
if loaded_shard_id is None:
return
if isinstance(loaded_shard_id, str):
if loaded_shard_id not in ["q", "k", "v"]:
raise ValueError(
"Shard id for QKVParallelLinear should be 'q', 'k', or 'v', "
f"got shard id {loaded_shard_id}."
)
return
raise ValueError("This line should not be reached")
def _get_shard_offset_mapping(self, loaded_shard_id: str): def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = { shard_offset_mapping = {
"q": 0, "q": 0,
@@ -1073,6 +1131,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None, loaded_shard_id: str | None = None,
): ):
self.validate_shard_id(loaded_shard_id)
if loaded_shard_id is None: # special case for certain models if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight( param.load_qkv_weight(
@@ -1112,6 +1171,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None, loaded_shard_id: str | None = None,
): ):
self.validate_shard_id(loaded_shard_id)
# Special case for GGUF # Special case for GGUF
# initialize GGUF param after we know the quantize type # initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)