[Doc] Add developer guide for CustomOp (#30886)
Signed-off-by: shen-shanshan <467638484@qq.com>
@@ -181,6 +181,7 @@ def get_masked_input_and_mask(
     return input_, ~vocab_mask
 
 
+# --8<-- [start:vocab_parallel_embedding]
 @CustomOp.register("vocab_parallel_embedding")
 class VocabParallelEmbedding(CustomOp):
     """Embedding parallelized in the vocabulary dimension.
@@ -221,6 +222,8 @@ class VocabParallelEmbedding(CustomOp):
         prefix: full name of the layer in the state dict
     """  # noqa: E501
 
+    # --8<-- [end:vocab_parallel_embedding]
+
     def __init__(
         self,
         num_embeddings: int,
@@ -492,6 +495,7 @@ class VocabParallelEmbedding(CustomOp):
         return s
 
 
+# --8<-- [start:parallel_lm_head]
 @CustomOp.register("parallel_lm_head")
 class ParallelLMHead(VocabParallelEmbedding):
     """Parallelized LM head.
@@ -509,6 +513,8 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: padding size for the vocabulary.
     """
 
+    # --8<-- [end:parallel_lm_head]
+
     def __init__(
        self,
        num_embeddings: int,
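Both hunks also show the registration pattern the guide documents: decorate a subclass of CustomOp with CustomOp.register("<name>"). As a minimal, hypothetical sketch of defining a new op with that pattern: the op name my_silu_and_mul, the class body, the forward_native/forward_cuda hooks, and the import path vllm.model_executor.custom_op are assumptions based on vLLM's CustomOp dispatch interface and do not appear in this diff.

import torch

from vllm.model_executor.custom_op import CustomOp  # assumed import path


@CustomOp.register("my_silu_and_mul")  # hypothetical op name
class MySiluAndMul(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Portable reference implementation: SiLU on the first half of the
        # last dimension, multiplied elementwise by the second half.
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Backend-specific hook (assumed from vLLM's CustomOp interface).
        # A real op would typically call a fused kernel here; reusing the
        # native math keeps this sketch self-contained and runnable.
        return self.forward_native(x)

The string passed to CustomOp.register is how the op is referred to elsewhere in vLLM (for example when custom implementations are enabled or disabled by name), which is why VocabParallelEmbedding and ParallelLMHead each register under their own name in the hunks above.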