[Doc] Add developer guide for CustomOp (#30886)

Signed-off-by: shen-shanshan <467638484@qq.com>
Author: Shanshan Shen
Date: 2026-01-10 00:21:11 +08:00
Committed by: GitHub
Parent: ac9f9330e6
Commit: 08d954f036
24 changed files with 441 additions and 5 deletions

vllm/model_executor/layers/linear.py

@@ -304,6 +304,7 @@ class LinearBase(CustomOp):
        param.tp_size = self.tp_size

# --8<-- [start:replicated_linear]
@CustomOp.register("replicated_linear")
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.
@@ -321,6 +322,8 @@ class ReplicatedLinear(LinearBase):
        disable_tp: Has no effect for replicated linear layers.
    """
    # --8<-- [end:replicated_linear]

    def __init__(
        self,
        input_size: int,
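
The `@CustomOp.register(...)` decorator shown in these hunks is the hook the new guide documents: it registers a layer class under a string name so vLLM can dispatch to backend-specific implementations. A minimal sketch of defining a new op the same way, assuming the public `CustomOp` base from `vllm.model_executor.custom_op` and its `forward_native`/`forward_cuda` hooks (the op name and body here are illustrative, modeled on vLLM's built-in `SiluAndMul`):

```python
import torch
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp


@CustomOp.register("my_silu_and_mul")  # hypothetical name, for illustration only
class MySiluAndMul(CustomOp):
    """Computes silu(x[..., :d]) * x[..., d:], like vLLM's SiluAndMul."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference path, used when no custom kernel applies.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused CUDA kernel here; this sketch just
        # falls back to the native implementation.
        return self.forward_native(x)
```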
@@ -421,6 +424,7 @@ class ReplicatedLinear(LinearBase):
        return s

# --8<-- [start:column_parallel_linear]
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.
@@ -448,6 +452,8 @@ class ColumnParallelLinear(LinearBase):
        disable_tp: If true, the weight matrix won't be sharded across TP ranks.
    """
    # --8<-- [end:column_parallel_linear]

    def __init__(
        self,
        input_size: int,
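
The `# --8<-- [start:...]`/`# --8<-- [end:...]` comments are pymdownx.snippets section markers: they fence off the decorator, class signature, and docstring so the new developer guide can embed them verbatim. Assuming the docs build uses MkDocs with the `pymdownx.snippets` extension (which this marker syntax belongs to) and that these hunks come from `vllm/model_executor/layers/linear.py`, a guide page would pull in a marked region like this:

```markdown
--8<-- "vllm/model_executor/layers/linear.py:column_parallel_linear"
```

Only the region between the matching start/end markers is inlined, which is why the markers bracket the docstring but stop before `__init__`.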
@@ -1289,6 +1295,7 @@ class QKVParallelLinear(ColumnParallelLinear):
            param_data.copy_(loaded_weight)

# --8<-- [start:row_parallel_linear]
@CustomOp.register("row_parallel_linear")
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.
@@ -1323,6 +1330,8 @@ class RowParallelLinear(LinearBase):
        disable_tp: If true, the weight matrix won't be sharded across TP ranks.
    """
    # --8<-- [end:row_parallel_linear]

    def __init__(
        self,
        input_size: int,
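
For readers of the guide, it may help to see what a registration decorator of this shape typically does. The following is a sketch under the assumption of a simple name-to-class registry, not vLLM's exact implementation:

```python
import torch.nn as nn


class CustomOp(nn.Module):
    """Sketch of a registry-backed op base class (assumed structure)."""

    # Maps op names such as "replicated_linear" to their implementing classes.
    op_registry: dict[str, type["CustomOp"]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(op_cls: type["CustomOp"]) -> type["CustomOp"]:
            assert name not in cls.op_registry, f"duplicate op name: {name}"
            cls.op_registry[name] = op_cls
            return op_cls

        return decorator
```

Registering by name is what lets ops be toggled per deployment: in vLLM, entries in `CompilationConfig.custom_ops` such as `"+rms_norm"` or `"-silu_and_mul"` enable or disable registered ops by these names.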