[Doc] Add developer guide for CustomOp (#30886)
Signed-off-by: shen-shanshan <467638484@qq.com>
@@ -304,6 +304,7 @@ class LinearBase(CustomOp):
         param.tp_size = self.tp_size
 
 
+# --8<-- [start:replicated_linear]
 @CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
@@ -321,6 +322,8 @@ class ReplicatedLinear(LinearBase):
         disable_tp: Take no effect for replicated linear layers.
     """
 
+    # --8<-- [end:replicated_linear]
+
     def __init__(
         self,
         input_size: int,
@@ -421,6 +424,7 @@ class ReplicatedLinear(LinearBase):
         return s
 
 
+# --8<-- [start:column_parallel_linear]
 @CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
@@ -448,6 +452,8 @@ class ColumnParallelLinear(LinearBase):
         disable_tp: If true, weights matrix won't be sharded through tp rank.
     """
 
+    # --8<-- [end:column_parallel_linear]
+
     def __init__(
         self,
         input_size: int,
@@ -1289,6 +1295,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)
 
 
+# --8<-- [start:row_parallel_linear]
 @CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
@@ -1323,6 +1330,8 @@ class RowParallelLinear(LinearBase):
         disable_tp: If true, weights matrix won't be sharded through tp rank.
     """
 
+    # --8<-- [end:row_parallel_linear]
+
     def __init__(
         self,
         input_size: int,
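The # --8<-- [start:...] / [end:...] comments added above are snippet-section markers for the MkDocs snippets extension (pymdown-extensions): a docs page can embed exactly the region between a matching start/end pair. As a sketch of how the new developer guide would consume them, assuming the file touched here is vllm/model_executor/layers/linear.py (the extracted diff does not show the file header) and that the vLLM docs use the standard snippets syntax, the guide would inline the ReplicatedLinear excerpt with a line such as:

--8<-- "vllm/model_executor/layers/linear.py:replicated_linear"

Note that each [end:...] marker sits right after the class docstring, so the rendered excerpt covers only the @CustomOp.register decorator, the class declaration, and its docstring, not the full __init__ body.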
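For context on what @CustomOp.register(...) buys these layers: vLLM's CustomOp base class dispatches forward to a platform-specific implementation (forward_cuda, forward_cpu, ...) and falls back to forward_native, and registering an op under a name is what allows it to be toggled through vLLM's custom-op settings (e.g. the custom_ops list in the compilation config). Below is a minimal sketch of the pattern, with a hypothetical op name and class that are not part of this commit:

import torch

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("toy_linear")  # hypothetical op name, for illustration only
class ToyLinear(CustomOp):
    """Toy layer illustrating the CustomOp registration/dispatch pattern."""

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size))

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Portable PyTorch path; CustomOp falls back to this when the op
        # is disabled or no platform-specific override exists.
        return torch.nn.functional.linear(x, self.weight)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call an optimized CUDA kernel here; this sketch
        # simply reuses the native implementation.
        return self.forward_native(x)

The classes in this diff follow the same shape: LinearBase subclasses CustomOp, and each concrete linear variant (ReplicatedLinear, ColumnParallelLinear, RowParallelLinear) is registered under its own op name via the decorator.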