[Core] Support tensor parallelism for GGUF quantization (#7520)
This commit is contained in:
@@ -507,11 +507,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id
|
||||
|
||||
if is_gguf_weight:
|
||||
shard_size = loaded_weight.shape[output_dim]
|
||||
shard_offset = loaded_weight.shape[output_dim] * \
|
||||
loaded_shard_id
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
shard_shape = list(loaded_weight.shape)
|
||||
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
|
||||
param.shard_id.append(loaded_shard_id)
|
||||
param.shard_size[loaded_shard_id] = loaded_weight.shape
|
||||
param.shard_size[loaded_shard_id] = shard_shape
|
||||
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
input_size = loaded_weight.shape[input_dim]
|
||||
param_data = param_data.narrow(input_dim, 0, input_size)
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
@@ -863,8 +868,13 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param, orig_qkv_offsets, loaded_shard_id)
|
||||
|
||||
if is_gguf_weight:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
shard_shape = list(loaded_weight.shape)
|
||||
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
|
||||
param.shard_id.append(loaded_shard_id)
|
||||
param.shard_size[loaded_shard_id] = loaded_weight.shape
|
||||
param.shard_size[loaded_shard_id] = shard_shape
|
||||
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
input_size = loaded_weight.shape[input_dim]
|
||||
param_data = param_data.narrow(input_dim, 0, input_size)
|
||||
@@ -976,6 +986,7 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
|
||||
# Special case for GGUF
|
||||
@@ -986,7 +997,10 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
weight_shape = list(loaded_weight.shape)
|
||||
if input_dim:
|
||||
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
|
||||
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
if input_dim is not None:
|
||||
|
||||
Reference in New Issue
Block a user