[Core] Support loading GGUF model (#5191)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
@@ -3,19 +3,46 @@ from typing import List, Optional, Sequence, Tuple
 
 import torch
 import torch.nn.functional as F
-from torch.nn.parameter import Parameter
+from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig, QuantizeMethodBase)
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.utils import set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
 
+class UnquantizedEmbeddingMethod(QuantizeMethodBase):
+    """Unquantized method for embeddings."""
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for embedding layer."""
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=params_dtype),
+                           requires_grad=False)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return F.linear(x, layer.weight, bias)
+
+    def embedding(self, layer: torch.nn.Module,
+                  input_: torch.Tensor) -> torch.Tensor:
+        return F.embedding(input_, layer.weight)
+
+
 def pad_vocab_size(vocab_size: int,
                    pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
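Taken on its own, the new class can be exercised outside vLLM's layer machinery; a minimal sketch, with made-up sizes (a 128-token vocab shard, hidden size 16):

import torch

layer = torch.nn.Module()
method = UnquantizedEmbeddingMethod()
method.create_weights(layer,
                      input_size_per_partition=16,
                      output_partition_sizes=[128],
                      input_size=16,
                      output_size=128,
                      params_dtype=torch.float32)
token_ids = torch.tensor([0, 3, 7])
embeddings = method.embedding(layer, token_ids)  # gathers rows of layer.weight
assert embeddings.shape == (3, 16)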
@@ -199,7 +226,19 @@ class VocabParallelEmbedding(torch.nn.Module):
         if quant_config is not None:
             linear_method = quant_config.get_quant_method(self, prefix=prefix)
         if linear_method is None:
-            linear_method = UnquantizedLinearMethod()
+            linear_method = UnquantizedEmbeddingMethod()
+
+        # If we are making an embedding layer, then our quantization linear
+        # method must implement the embedding operation. If we are another
+        # layer type like ParallelLMHead, this is not important.
+        is_embedding_layer = type(self) is VocabParallelEmbedding
+        linear_method_implements_embedding = method_has_implemented_embedding(
+            type(linear_method))
+        if is_embedding_layer and not linear_method_implements_embedding:
+            raise NotImplementedError(
+                f"The class {type(linear_method).__name__} must implement "
+                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+
         self.linear_method: QuantizeMethodBase = linear_method
 
         if params_dtype is None:
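method_has_implemented_embedding lives in base_config.py and is not part of this diff; a plausible sketch, assuming it only checks that the method class overrides the base `embedding` stub:

import inspect
from typing import Type

def method_has_implemented_embedding(
        method_class: Type[QuantizeMethodBase]) -> bool:
    # Compare the class attribute against the QuantizeMethodBase stub:
    # the check passes only if some subclass actually overrode `embedding`.
    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
                                            None)
    class_embedding = inspect.getattr_static(method_class, "embedding", None)
    return (class_embedding is not None
            and class_embedding is not base_embedding)

assert method_has_implemented_embedding(UnquantizedEmbeddingMethod)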
@@ -306,6 +345,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         output_dim = getattr(param, "output_dim", None)
         packed_dim = getattr(param, "packed_dim", None)
 
+        # If the parameter is a gguf weight, then load it directly.
+        if getattr(param, "is_gguf_weight_type", None):
+            param.data.copy_(loaded_weight)
+            param.weight_type = loaded_weight.item()
+            return
+        elif isinstance(param, UninitializedParameter):
+            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
+
         # If parameter does not have output dim, then it should
         # be copied onto all gpus (e.g. g_idx for act_order gptq).
         if output_dim is None:
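The new elif branch handles GGUF checkpoints whose tensor shapes are only known once the file is read; a standalone sketch of the materialize-then-copy pattern (sizes are made up):

import torch
from torch.nn.parameter import UninitializedParameter

param = UninitializedParameter()
loaded_weight = torch.randn(1024, 512)  # stand-in for a tensor read from GGUF
# materialize() allocates storage and converts the placeholder into a
# regular torch.nn.Parameter; normal copy-based loading then applies.
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
param.data.copy_(loaded_weight)
assert isinstance(param, torch.nn.Parameter)
assert param.shape == (1024, 512)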
@@ -344,7 +391,8 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = F.embedding(masked_input.long(), self.weight)
+        output_parallel = self.linear_method.embedding(self,
+                                                       masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
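Dispatching the lookup through linear_method.embedding, rather than calling F.embedding on self.weight directly, is what lets a quantization method own the gather. A hypothetical illustration (the int8 scheme and every name in it are invented here, not vLLM's GGUF path):

import torch
import torch.nn.functional as F

class Int8EmbeddingMethodSketch:
    """Toy method: dequantize the packed table, then gather as usual."""

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        weight = layer.qweight.float() * layer.scale  # dequantize
        return F.embedding(input_, weight)

layer = torch.nn.Module()
layer.qweight = torch.randint(-128, 127, (128, 16), dtype=torch.int8)
layer.scale = torch.tensor(0.05)
out = Int8EmbeddingMethodSketch().embedding(layer, torch.tensor([1, 2]))
assert out.shape == (2, 16)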
@@ -389,6 +437,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
                          org_num_embeddings, padding_size, quant_config,
                          prefix)
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,