[Core] Support loading GGUF model (#5191)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
-from torch.nn.parameter import Parameter
+from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -311,6 +311,17 @@ class ColumnParallelLinear(LinearBase):
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
+
+        # Special case for GGUF
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.weight_type = loaded_weight.item()
+
+        # Materialize GGUF UninitializedParameter
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+
         param_data = param.data
         if output_dim is not None:
             shard_size = param_data.shape[output_dim]
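A note outside the diff: a GGUF tensor's byte shape depends on its quantization type and so isn't known when the layer is constructed, which is why this hunk swaps `Parameter` for torch's lazy `UninitializedParameter` and materializes it when the first weight arrives. A minimal standalone sketch of that pattern (the module and method names are hypothetical, not vLLM's):

import torch
from torch.nn.parameter import UninitializedParameter

class LazyWeightModule(torch.nn.Module):
    """Hypothetical layer whose weight shape is unknown until load time."""

    def __init__(self):
        super().__init__()
        # No storage allocated yet; shape/dtype come from the checkpoint.
        # requires_grad=False since quantized byte buffers aren't trainable.
        self.weight = UninitializedParameter(requires_grad=False)

    def load(self, loaded_weight: torch.Tensor):
        if isinstance(self.weight, UninitializedParameter):
            # Allocate backing storage to match the (possibly quantized)
            # tensor read from disk, e.g. a uint8 buffer of GGML blocks.
            self.weight.materialize(loaded_weight.shape,
                                    dtype=loaded_weight.dtype)
        self.weight.data.copy_(loaded_weight)

m = LazyWeightModule()
m.load(torch.randint(0, 255, (8, 18), dtype=torch.uint8))
print(m.weight.shape)  # torch.Size([8, 18])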
@@ -398,6 +409,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                       param: Parameter,
                       loaded_weight: torch.Tensor,
                       loaded_shard_id: Optional[int] = None):
+
+        # Special case for GGUF
+        # initialize GGUF param after we know the quantize type
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.data[loaded_shard_id].copy_(loaded_weight)
+            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            return
+
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            from gguf.constants import GGML_QUANT_SIZES
+
+            ori_shape = param.tensor_shape
+            weight_types = self.qweight_type.shard_weight_type.values()
+            row_size = []
+            for weight_type in weight_types:
+                block_size, type_size = GGML_QUANT_SIZES[weight_type]
+                row_size.append(ori_shape[1] // block_size * type_size)
+            q_shape = (ori_shape[0], max(row_size))
+            param.materialize(q_shape, dtype=loaded_weight.dtype)
+
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         # Special case for AQLM codebooks.
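The `row_size` arithmetic above converts a logical row length into its on-disk byte width: GGML packs each group of `block_size` values into `type_size` bytes. A standalone check against gguf's published table (the 4096-wide row is an arbitrary example size, not a value from the commit):

from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

# Q4_0 stores 32 values in 18 bytes (fp16 scale + 16 bytes of nibbles);
# Q8_0 stores 32 values in 34 bytes (fp16 scale + 32 int8 values).
for qtype in (GGMLQuantizationType.Q4_0, GGMLQuantizationType.Q8_0):
    block_size, type_size = GGML_QUANT_SIZES[qtype]
    row_bytes = 4096 // block_size * type_size
    print(qtype.name, row_bytes)  # Q4_0 -> 2304, Q8_0 -> 4352

# materialize() uses max(row_size) so one buffer can hold whichever shard
# has the widest per-row byte layout; narrower shards occupy a leading
# slice of each row.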
@@ -460,6 +492,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id
+
+            if is_gguf_weight:
+                shard_size = loaded_weight.shape[output_dim]
+                shard_offset = loaded_weight.shape[output_dim] * \
+                    loaded_shard_id
+                param.shard_id.append(loaded_shard_id)
+                param.shard_size[loaded_shard_id] = loaded_weight.shape
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
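For the merged gate/up projection the two shards are the same size, so the GGUF branch above derives each shard's offset directly from the loaded tensor's shape. A toy illustration of that offset math (the 11008 x 2304 dimensions are invented, not values from the commit):

# output_dim == 0: shards are stacked along the output dimension.
loaded_shape = (11008, 2304)  # (output_size, quantized row bytes)
for loaded_shard_id in (0, 1):  # 0 = gate_proj, 1 = up_proj
    shard_size = loaded_shape[0]
    shard_offset = loaded_shape[0] * loaded_shard_id
    print(loaded_shard_id, shard_offset, shard_size)
# Shard 0 occupies rows [0, 11008), shard 1 rows [11008, 22016);
# narrow(output_dim, shard_offset, shard_size) selects the right slice.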
@@ -563,6 +602,29 @@ class QKVParallelLinear(ColumnParallelLinear):
                       param: Parameter,
                       loaded_weight: torch.Tensor,
                       loaded_shard_id: Optional[str] = None):
+
+        # Special case for GGUF
+        # initialize GGUF param after we know the quantize type
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type and loaded_shard_id is not None:
+            idx_map = {"q": 0, "k": 1, "v": 2}
+            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
+            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            return
+
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            from gguf.constants import GGML_QUANT_SIZES
+
+            ori_shape = param.tensor_shape
+            weight_types = self.qweight_type.shard_weight_type.values()
+            row_size = []
+            for weight_type in weight_types:
+                block_size, type_size = GGML_QUANT_SIZES[weight_type]
+                row_size.append(ori_shape[1] // block_size * type_size)
+            q_shape = (ori_shape[0], max(row_size))
+            param.materialize(q_shape, dtype=loaded_weight.dtype)
+
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         # Special case for AQLM codebooks.
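The QKV loader keys shards by the strings "q"/"k"/"v", while the stacked weight-type tensor is indexed by position, hence the `idx_map`. A small sketch of that bookkeeping (the type ids follow gguf's enum, where Q4_0 = 2 and Q8_0 = 8, but the per-shard mix is invented):

idx_map = {"q": 0, "k": 1, "v": 2}
shard_weight_type = {}
# Each projection can, in principle, carry its own quantization type.
incoming = {"q": 2, "k": 2, "v": 8}  # GGMLQuantizationType values
for shard_id, weight_type in incoming.items():
    shard_weight_type[shard_id] = weight_type
    slot = idx_map[shard_id]  # row to write in the 3-slot type tensor
    print(shard_id, slot, weight_type)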
@@ -650,6 +712,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size, shard_offset = adjust_bitsandbytes_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
+
+            if is_gguf_weight:
+                param.shard_id.append(loaded_shard_id)
+                param.shard_size[loaded_shard_id] = loaded_weight.shape
+                input_dim = getattr(param, "input_dim", None)
+                input_size = loaded_weight.shape[input_dim]
+                param_data = param_data.narrow(input_dim, 0, input_size)
 
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
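Because the merged QKV buffer was materialized for the widest quantized row, a shard with a narrower byte layout must first trim the destination view along `input_dim` before copying. A standalone illustration (the 4352/2304 widths reuse the Q8_0/Q4_0 arithmetic above; the shapes themselves are invented):

import torch

param_data = torch.empty(4096, 4352, dtype=torch.uint8)     # sized for Q8_0
loaded_weight = torch.zeros(4096, 2304, dtype=torch.uint8)  # a Q4_0 shard
input_dim = 1
input_size = loaded_weight.shape[input_dim]
# Trim the destination to the incoming row width, then copy in place.
dst = param_data.narrow(input_dim, 0, input_size)
dst.copy_(loaded_weight)
print(dst.shape)  # torch.Size([4096, 2304])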
@@ -755,6 +824,17 @@ class RowParallelLinear(LinearBase):
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
+
+        # Special case for GGUF
+        is_gguf_weight = getattr(param, "is_gguf_weight", False)
+        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+        if is_gguf_weight_type:
+            param.weight_type = loaded_weight.item()
+
+        # Materialize GGUF UninitializedParameter
+        if is_gguf_weight and isinstance(param, UninitializedParameter):
+            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+
         param_data = param.data
         if input_dim is not None:
             shard_size = param_data.shape[input_dim]
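As in `ColumnParallelLinear`, the row-parallel loader records the tensor's GGML quantization type before any data arrives: the type reaches `weight_loader` as a 0-dim tensor, and `.item()` unwraps it to a plain Python int kept on the parameter for later kernel dispatch. A one-line check (the value 2 corresponds to Q4_0 in gguf's enum):

import torch

loaded_weight = torch.tensor(2)     # weight-type scalar from the loader
weight_type = loaded_weight.item()  # -> plain Python int
assert isinstance(weight_type, int) and weight_type == 2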