[ Misc ] fbgemm checkpoints (#6559)
@@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
         quant_config: quant config for the layer
+        prefix: full name of the layer in the state dict
     """  # noqa: E501
 
     def __init__(self,
@@ -169,7 +170,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
         # Keep the input dimensions.
@@ -195,7 +197,7 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         linear_method = None
         if quant_config is not None:
-            linear_method = quant_config.get_quant_method(self)
+            linear_method = quant_config.get_quant_method(self, prefix=prefix)
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method: QuantizeMethodBase = linear_method
@@ -382,9 +384,11 @@ class ParallelLMHead(VocabParallelEmbedding):
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size, quant_config)
+                         org_num_embeddings, padding_size, quant_config,
+                         prefix)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
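
Threading the layer's state-dict prefix down to get_quant_method lets a quantization config make a per-layer decision: fbgemm-style FP8 checkpoints typically list modules (for example the LM head) that were left unquantized, and the config needs the layer's full name to match against that list. The snippet below is a minimal, self-contained sketch of that pattern under assumed names; ToyFp8Config, is_layer_skipped, and modules_to_not_convert are illustrative placeholders, not the actual vLLM implementation.

from typing import Iterable

import torch


def is_layer_skipped(prefix: str, ignored_modules: Iterable[str]) -> bool:
    # Hypothetical helper: a layer is skipped if its full state-dict name
    # (e.g. "lm_head" or "model.layers.0.mlp.down_proj") matches one of the
    # checkpoint's unquantized modules.
    return any(prefix == name or prefix.endswith("." + name)
               for name in ignored_modules)


class ToyFp8Config:
    """Illustrative stand-in for a QuantizationConfig subclass."""

    def __init__(self, modules_to_not_convert: Iterable[str]):
        self.modules_to_not_convert = list(modules_to_not_convert)

    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
        # With the prefix available, the config can return None ("leave this
        # layer unquantized") for excluded layers and a quantized method
        # object for everything else.
        if is_layer_skipped(prefix, self.modules_to_not_convert):
            return None
        return "fp8-linear-method"  # placeholder for a real method object


config = ToyFp8Config(modules_to_not_convert=["lm_head"])
print(config.get_quant_method(torch.nn.Linear(8, 8), prefix="lm_head"))                # None
print(config.get_quant_method(torch.nn.Linear(8, 8), prefix="model.layers.0.q_proj"))  # fp8-linear-method

The unquantized fallback is exactly what the diff above shows: when get_quant_method returns None, the embedding layer drops back to UnquantizedLinearMethod().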