[CORE] [QUANT] Support for GPTQModel's dynamic quantization per module override/control (#7086)
parent 2c2b560f48
commit 36a08630e8
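For context, "dynamic" per-module control means the quantization config can carry overrides keyed by module-name patterns instead of a single global setting, so individual modules can use different parameters or be skipped entirely. A hypothetical shape of such a section is sketched below; the keys, values, and pattern syntax are illustrative assumptions, not GPTQModel's exact schema.

# Illustrative only: per-module overrides keyed by module-name regex.
# A value of None would mean "leave this module unquantized"; a dict would
# override parameters such as bit width or group size for matching modules.
dynamic_overrides = {
    r"model\.layers\.0\..*": {"bits": 8, "group_size": 64},
    r"lm_head": None,
}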
@@ -226,24 +226,24 @@ class VocabParallelEmbedding(torch.nn.Module):
                                                    self.tp_size)
         self.embedding_dim = embedding_dim

-        linear_method = None
+        quant_method = None
         if quant_config is not None:
-            linear_method = quant_config.get_quant_method(self, prefix=prefix)
-        if linear_method is None:
-            linear_method = UnquantizedEmbeddingMethod()
+            quant_method = quant_config.get_quant_method(self, prefix=prefix)
+        if quant_method is None:
+            quant_method = UnquantizedEmbeddingMethod()

         # If we are making an embedding layer, then our quantization linear
         # method must implement the embedding operation. If we are another
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
-        linear_method_implements_embedding = method_has_implemented_embedding(
-            type(linear_method))
-        if is_embedding_layer and not linear_method_implements_embedding:
+        quant_method_implements_embedding = method_has_implemented_embedding(
+            type(quant_method))
+        if is_embedding_layer and not quant_method_implements_embedding:
             raise NotImplementedError(
-                f"The class {type(linear_method).__name__} must implement "
+                f"The class {type(quant_method).__name__} must implement "
                 "the 'embedding' method, see UnquantizedEmbeddingMethod.")

-        self.linear_method: QuantizeMethodBase = linear_method
+        self.quant_method: QuantizeMethodBase = quant_method

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -260,13 +260,13 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.shard_indices.added_vocab_end_index -
             self.shard_indices.added_vocab_start_index)

-        self.linear_method.create_weights(self,
-                                          self.embedding_dim,
-                                          [self.num_embeddings_per_partition],
-                                          self.embedding_dim,
-                                          self.num_embeddings_padded,
-                                          params_dtype=params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.quant_method.create_weights(self,
+                                         self.embedding_dim,
+                                         [self.num_embeddings_per_partition],
+                                         self.embedding_dim,
+                                         self.num_embeddings_padded,
+                                         params_dtype=params_dtype,
+                                         weight_loader=self.weight_loader)

     @classmethod
     def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
@@ -412,8 +412,8 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.linear_method.embedding(self,
-                                                       masked_input.long())
+        output_parallel = self.quant_method.embedding(self,
+                                                      masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
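The rename above makes the per-module control path explicit: get_quant_method() is called with the layer's prefix and may return None for that particular module, in which case the layer falls back to an unquantized method. The following is a minimal, self-contained sketch of that fallback, using hypothetical stand-in classes rather than vLLM's actual ones.

# Minimal sketch (hypothetical names, not vLLM's classes) of the fallback the
# diff relies on: get_quant_method() may return None for a given module
# prefix, and the layer then installs an unquantized method instead.
from typing import Optional


class UnquantizedEmbeddingMethodSketch:
    """Stand-in for UnquantizedEmbeddingMethod: keeps plain fp16/fp32 weights."""


class GPTQMethodSketch:
    """Stand-in for a GPTQ quantize method."""


class QuantConfigSketch:
    """Per-module control: modules listed in `skip` are left unquantized."""

    def __init__(self, skip: set):
        self.skip = skip

    def get_quant_method(self, layer, prefix: str) -> Optional[GPTQMethodSketch]:
        # Returning None signals "do not quantize this module"; the caller
        # (like VocabParallelEmbedding above) supplies the unquantized fallback.
        return None if prefix in self.skip else GPTQMethodSketch()


config = QuantConfigSketch(skip={"lm_head"})

for prefix in ("model.embed_tokens", "lm_head"):
    quant_method = config.get_quant_method(layer=None, prefix=prefix)
    if quant_method is None:
        quant_method = UnquantizedEmbeddingMethodSketch()
    print(prefix, "->", type(quant_method).__name__)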