[ Misc ] Refactor w8a8 to use process_weights_after_load (Simplify Weight Loading) (#5940)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
This commit is contained in:
@@ -41,6 +41,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
|
||||
return quantized_size, quantized_offset
|
||||
|
||||
|
||||
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
||||
"""For fused modules (QKV and MLP) we have an array of length
|
||||
N that holds 1 scale for each "logical" matrix. So the param
|
||||
is an array of length N. The loaded_weight corresponds to
|
||||
one of the shards on disk. Here, we slice the param based on
|
||||
the shard_id for loading.
|
||||
"""
|
||||
qkv_idxs = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
if isinstance(shard_id, str):
|
||||
shard_id = qkv_idxs[shard_id]
|
||||
elif not isinstance(shard_id, int):
|
||||
raise ValueError(f"Unknown Shard Id {shard_id}")
|
||||
|
||||
# AutoFP8 scales do not have a shape
|
||||
# compressed-tensors scales do have a shape
|
||||
if len(loaded_weight.shape) != 0:
|
||||
assert loaded_weight.shape[0] == 1
|
||||
loaded_weight = loaded_weight[0]
|
||||
|
||||
return param[shard_id], loaded_weight
|
||||
|
||||
|
||||
class LinearMethodBase(QuantizeMethodBase):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@@ -358,37 +381,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
|
||||
param_shard_splitter = getattr(param, "shard_splitter", None)
|
||||
|
||||
if output_dim is not None and param_shard_splitter is not None:
|
||||
raise NotImplementedError(
|
||||
"We do not currently support output_dim != None and "
|
||||
"shard_splitter != None for a parameter. Please open an issue."
|
||||
)
|
||||
# If a parameter has defined a shard_splitter to be used for
|
||||
# the weight, it should be applied before the weight is
|
||||
# loaded/copied to the parameter. The shard_splitter applies
|
||||
# logic by using the loaded_shard_id to ensure that the loaded
|
||||
# param is loaded to the correct location
|
||||
# within the parameter defined by the linear method.
|
||||
if loaded_shard_id is None and param_shard_splitter is not None:
|
||||
raise NotImplementedError(
|
||||
"We do not currently support loaded_shard_id == None and "
|
||||
"shard_splitter != None for a parameter. Please open an issue."
|
||||
)
|
||||
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
# Special case for per-tensor scale to load scalar into fused array.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
if output_dim is None:
|
||||
# If fp8 + scale, need to send to each shard.
|
||||
if fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
if needs_scalar_to_array is not None:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, 0)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
@@ -450,15 +451,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = loaded_shard_id * shard_size
|
||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
||||
|
||||
# If a param_shard_splitter is defined by the LinearMethod, use it.
|
||||
elif param_shard_splitter is not None:
|
||||
logical_widths = getattr(param, "logical_widths", None)
|
||||
param_data, loaded_weight = param_shard_splitter(
|
||||
param_data, loaded_weight, loaded_shard_id, logical_widths)
|
||||
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
# Special case for per-tensor scales in fused case.
|
||||
elif needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
|
||||
else:
|
||||
@@ -548,36 +543,15 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
|
||||
param_shard_splitter = getattr(param, "shard_splitter", None)
|
||||
|
||||
if output_dim is not None and param_shard_splitter is not None:
|
||||
raise NotImplementedError(
|
||||
"We do not currently support output_dim != None and "
|
||||
"shard_splitter != None for a parameter. Please open an issue."
|
||||
)
|
||||
# If a parameter has defined a shard_splitter to be used for
|
||||
# the weight, it should be applied before the weight is
|
||||
# loaded/copied to the parameter. The shard_splitter applies
|
||||
# logic by using the loaded_shard_id to ensure that the loaded
|
||||
# param is loaded to the correct location
|
||||
# within the parameter defined by the linear method.
|
||||
if loaded_shard_id is None and param_shard_splitter is not None:
|
||||
raise NotImplementedError(
|
||||
"We do not currently support loaded_shard_id == None and "
|
||||
"shard_splitter != None for a parameter. Please open an issue."
|
||||
)
|
||||
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
# Special case for per-tensor scales in fused case.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
if output_dim is None:
|
||||
# If fp8 + scale, need to send to each shard.
|
||||
if fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
if needs_scalar_to_array is not None:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, 0)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
@@ -667,15 +641,9 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
||||
param_data = param_data.narrow(0, shard_index * shard_size,
|
||||
shard_size)
|
||||
# If a param_shard_splitter is defined by the LinearMethod, use it.
|
||||
elif param_shard_splitter is not None:
|
||||
logical_widths = getattr(param, "logical_widths", None)
|
||||
param_data, loaded_weight = param_shard_splitter(
|
||||
param_data, loaded_weight, loaded_shard_id, logical_widths)
|
||||
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
# Special case for per-tensor scales in fused case.
|
||||
elif needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
|
||||
Reference in New Issue
Block a user