[ Misc ] fbgemm checkpoints (#6559)

Author: Robert Shaw
Date: 2024-07-20 12:36:57 -04:00
Committed by: GitHub
Parent: 9042d68362
Commit: 683e3cb9c4
24 changed files with 234 additions and 47 deletions


@@ -161,6 +161,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
         quant_config: quant config for the layer
+        prefix: full name of the layer in the state dict
     """  # noqa: E501
 
     def __init__(self,
@@ -169,7 +170,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
         # Keep the input dimensions.
@@ -195,7 +197,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         linear_method = None
         if quant_config is not None:
-            linear_method = quant_config.get_quant_method(self)
+            linear_method = quant_config.get_quant_method(self, prefix=prefix)
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method: QuantizeMethodBase = linear_method
 
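
Threading the prefix into get_quant_method lets the quantization config make per-layer decisions, which fbgemm checkpoints rely on to leave some layers (typically lm_head) unquantized. A minimal sketch of the idea with stand-in classes; this is illustrative and is not the commit's actual FBGEMM config:

# Hypothetical sketch: a config whose get_quant_method keys off the new
# `prefix` argument. The class names below are stand-ins, not vLLM's.
from typing import List, Optional

import torch


class UnquantizedLinearMethodStub:
    """Stand-in for vLLM's UnquantizedLinearMethod."""


class Fp8LinearMethodStub:
    """Stand-in for a real quantized linear method."""


class SketchQuantConfig:

    def __init__(self, ignore_list: List[str]):
        # State-dict prefixes that must stay unquantized, e.g. ["lm_head"].
        self.ignore_list = ignore_list

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[object]:
        # `prefix` is the layer's full state-dict name, threaded through
        # VocabParallelEmbedding.__init__ in the hunk above.
        if prefix in self.ignore_list:
            return UnquantizedLinearMethodStub()
        return Fp8LinearMethodStub()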
@@ -382,9 +384,11 @@ class ParallelLMHead(VocabParallelEmbedding):
                  params_dtype: Optional[torch.dtype] = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size, quant_config)
+                         org_num_embeddings, padding_size, quant_config,
+                         prefix)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
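
Model code then passes each layer's state-dict name at construction time. A usage sketch, assuming a hypothetical model class (real vLLM models wire the prefix through their own constructors):

# Hypothetical usage: naming the LM head so per-layer quantization rules
# from the checkpoint (e.g. "skip lm_head") can match it.
import torch

from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead


class TinyModel(torch.nn.Module):

    def __init__(self, vocab_size: int, hidden_size: int, quant_config=None):
        super().__init__()
        self.lm_head = ParallelLMHead(vocab_size,
                                      hidden_size,
                                      quant_config=quant_config,
                                      prefix="lm_head")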