[Core] Refactor GGUF parameters packing and forwarding (#8859)
This commit is contained in:
@@ -86,15 +86,16 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
tensor_shape = (output_size_per_partition, input_size_per_partition)
|
||||
qweight = UninitializedParameter(requires_grad=False)
|
||||
qweight = GGUFUninitializedParameter(requires_grad=False)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"tensor_shape": tensor_shape,
|
||||
"is_gguf_weight": True,
|
||||
"shard_size": {},
|
||||
"data_container": [],
|
||||
"shard_id": [],
|
||||
"shard_id_map": {},
|
||||
})
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qweight", qweight)
|
||||
@@ -116,21 +117,17 @@ class GGUFLinearMethod(LinearMethodBase):
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
shard_size = getattr(layer.qweight, "shard_size", None)
|
||||
shard_id = getattr(layer.qweight, "shard_id", None)
|
||||
|
||||
if shard_id and shard_size:
|
||||
result = []
|
||||
offset = 0
|
||||
if shard_id:
|
||||
# dequantize shard weights respectively
|
||||
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
|
||||
qweight = layer.qweight.unbind(0)
|
||||
result = []
|
||||
for id in shard_id:
|
||||
shard_weight = layer.qweight[
|
||||
offset:offset +
|
||||
shard_size[id][0], :shard_size[id][1]].contiguous()
|
||||
q_idx = layer.qweight.shard_id_map[id]
|
||||
qweight_type = layer.qweight_type.shard_weight_type[id]
|
||||
result.append(_fuse_mul_mat(x, shard_weight, qweight_type))
|
||||
offset += shard_size[id][0]
|
||||
result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
|
||||
out = torch.cat(result, axis=1)
|
||||
else:
|
||||
qweight = layer.qweight
|
||||
@@ -162,3 +159,20 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
|
||||
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
|
||||
x_flat.shape[0])
|
||||
return dequant.view(*x.shape, hidden_size)
|
||||
|
||||
|
||||
class GGUFUninitializedParameter(UninitializedParameter):
|
||||
cls_to_become = Parameter
|
||||
data_container: List[torch.Tensor]
|
||||
|
||||
def materialize_nested(self) -> Parameter:
|
||||
nested_data = torch.nested.nested_tensor(self.data_container,
|
||||
device=self.device,
|
||||
dtype=torch.uint8)
|
||||
self.data_container.clear()
|
||||
param = torch.Tensor._make_subclass(self.cls_to_become,
|
||||
nested_data,
|
||||
require_grad=False)
|
||||
for k, v in self.__dict__.items():
|
||||
setattr(param, k, v)
|
||||
return param
|
||||
|
||||
Reference in New Issue
Block a user