[Kernel/Quant] Remove AQLM (#22943)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -692,8 +692,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
# Special case for per-tensor scale to load scalar into fused array.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
@@ -781,13 +779,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
if not is_sharded_weight:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_offset = loaded_shard_id * shard_size
|
||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
||||
|
||||
# Special case for per-tensor scales in fused case.
|
||||
elif needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
@@ -1081,8 +1072,6 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
|
||||
# Special case for per-tensor scales in fused case.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
@@ -1204,13 +1193,6 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
||||
param_data = param_data.narrow(0, shard_index * shard_size,
|
||||
shard_size)
|
||||
# Special case for per-tensor scales in fused case.
|
||||
elif needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
|
||||
Reference in New Issue
Block a user