[Core] Refactor GGUF parameters packing and forwarding (#8859)
@@ -440,17 +440,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
-
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
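The branch removed above pre-sized a single flat buffer for the merged weight: GGUF stores each row of a quantized matrix as ori_shape[1] // block_size blocks of type_size bytes, so sub-weights quantized with different types occupy different byte widths, and the old code padded everything to max(row_size). A minimal sketch of that arithmetic, using the real GGML_QUANT_SIZES table from the gguf package (the hidden size is a made-up example):

    # Byte width of one quantized row for a couple of GGUF quant types.
    # GGML_QUANT_SIZES maps quant type -> (block_size, type_size).
    from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

    hidden_size = 4096  # hypothetical ori_shape[1]

    row_bytes = {}
    for qtype in (GGMLQuantizationType.Q4_K, GGMLQuantizationType.Q6_K):
        block_size, type_size = GGML_QUANT_SIZES[qtype]
        # hidden_size logical weights -> hidden_size // block_size blocks,
        # each stored in type_size bytes.
        row_bytes[qtype.name] = hidden_size // block_size * type_size
    print(row_bytes)

GGUF checkpoints routinely mix quant types across the fused sub-weights, so the padded flat buffer over-allocates for the narrower shards; the added branch instead keeps each shard intact and defers packing until all shards have arrived.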
@@ -515,18 +521,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
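With shards now captured whole in the branch above, the per-shard shape bookkeeping here becomes dead code and is dropped; only the generic narrow-based slicing along output_dim remains. For reference, a self-contained sketch of that slicing (shape, world size, and rank are made up):

    import torch

    tp_size, tp_rank = 2, 1          # hypothetical world size / this rank
    weight = torch.randn(1024, 512)  # hypothetical full, unsharded weight
    output_dim = 0

    shard_size = weight.size(output_dim) // tp_size
    start_idx = tp_rank * shard_size
    # narrow(dim, start, length) returns a view of rows
    # [start_idx, start_idx + shard_size) without copying.
    shard = weight.narrow(output_dim, start_idx, shard_size)
    assert shard.shape == (512, 512)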
@@ -783,17 +777,23 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
-
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
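Both weight loaders now follow the same pattern: each still-quantized shard is narrowed to the local rank, appended to param.data_container, and the packed weight is built only once the expected number of shards has arrived (2 for the merged gate/up projection, 3 for fused QKV). materialize_nested() itself is defined on vLLM's GGUF parameter class and is not part of this diff; the sketch below imitates the bookkeeping with an ordinary container and stands in for the real packing with pad-and-concatenate, so it illustrates the pattern rather than vLLM's implementation:

    # Illustrative stand-in for the data_container / materialize_nested
    # pattern; names mirror the diff, not vLLM's actual parameter class.
    from typing import Dict, List, Optional

    import torch


    class ShardContainer:
        """Collects quantized GGUF shards; packs once all have arrived."""

        def __init__(self, num_shards: int):
            self.num_shards = num_shards
            self.shard_id: List = []      # arrival order, e.g. ["q", "k", "v"]
            self.shard_id_map: Dict = {}  # shard id -> index in data_container
            self.data_container: List[torch.Tensor] = []

        def add(self, shard_id, weight: torch.Tensor) -> Optional[torch.Tensor]:
            self.shard_id.append(shard_id)
            self.shard_id_map[shard_id] = len(self.data_container)
            self.data_container.append(weight)
            if len(self.data_container) == self.num_shards:
                return self.materialize_nested()
            return None  # keep waiting for the remaining shards

        def materialize_nested(self) -> torch.Tensor:
            # Shards quantized with different GGUF types have different byte
            # widths; pad each to the widest shard before concatenating.
            # (vLLM's real materialize_nested may pack differently.)
            max_cols = max(t.shape[-1] for t in self.data_container)
            padded = [torch.nn.functional.pad(t, (0, max_cols - t.shape[-1]))
                      for t in self.data_container]
            return torch.cat(padded, dim=0)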
@@ -883,18 +883,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
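On the read side, the removed param.shard_size table is no longer needed: a sub-weight is found through shard_id_map, which records the position each shard was appended at, so shards may arrive in any order. A hypothetical usage run continuing the ShardContainer sketch above (same imports; the 144- and 210-byte row widths echo the Q4_K and Q6_K type sizes):

    container = ShardContainer(num_shards=3)
    container.add("k", torch.empty(8, 144, dtype=torch.uint8))  # k arrives first
    container.add("q", torch.empty(8, 144, dtype=torch.uint8))
    qweight = container.add("v", torch.empty(8, 210, dtype=torch.uint8))
    assert qweight is not None  # the third shard triggers packing

    # No padded-offset arithmetic: look the shard up by id.
    q_bytes = container.data_container[container.shard_id_map["q"]]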