[CORE] Quantized lm-head Framework (#4442)

Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: ZX <zx@lbx.dev>
This commit is contained in:
Qubitium-ModelCloud
2024-07-03 06:25:17 +08:00
committed by GitHub
parent 7c008c51a9
commit ee93f4f92a
48 changed files with 268 additions and 121 deletions

View File

@@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -157,6 +160,7 @@ class VocabParallelEmbedding(torch.nn.Module):
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
quant_config: quant config for the layer
""" # noqa: E501
def __init__(self,
@@ -164,7 +168,8 @@ class VocabParallelEmbedding(torch.nn.Module):
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# Keep the input dimensions.
@@ -187,6 +192,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
linear_method = None
if quant_config is not None:
linear_method = quant_config.get_quant_method(self)
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method: QuantizeMethodBase = linear_method
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the vocaburaly dimension.
@@ -201,14 +214,14 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
self.linear_method.create_weights(self,
self.embedding_dim,
[self.num_embeddings_per_partition],
self.embedding_dim,
self.num_embeddings_padded,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@@ -288,10 +301,32 @@ class VocabParallelEmbedding(torch.nn.Module):
return ret
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
self.shard_indices.org_vocab_end_index]
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
assert param.data.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
return
# Shard indexes for loading the weight
start_idx = self.shard_indices.org_vocab_start_index
shard_size = self.shard_indices.org_vocab_end_index - start_idx
# If param packed on the same dim we are sharding on, then
# need to adjust offsets of loaded weight by pack_factor.
if packed_dim is not None and packed_dim == output_dim:
assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
param.pack_factor)
start_idx = start_idx // param.pack_factor
shard_size = shard_size // param.pack_factor
else:
assert loaded_weight.shape[output_dim] == self.org_vocab_size
# Copy the data.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)
@@ -346,16 +381,17 @@ class ParallelLMHead(VocabParallelEmbedding):
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size)
org_num_embeddings, padding_size, quant_config)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)