[Doc] Add developer guide for CustomOp (#30886)

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2026-01-10 00:21:11 +08:00
committed by GitHub
parent ac9f9330e6
commit 08d954f036
24 changed files with 441 additions and 5 deletions

View File

@@ -88,6 +88,7 @@ def dispatch_rocm_rmsnorm_func(
return rms_norm
# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
@@ -96,6 +97,8 @@ class RMSNorm(CustomOp):
Refer to https://arxiv.org/abs/1910.07467
"""
# --8<-- [end:rms_norm]
def __init__(
self,
hidden_size: int,
@@ -253,6 +256,7 @@ class RMSNorm(CustomOp):
return s
# --8<-- [start:gemma_rms_norm]
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.
@@ -262,6 +266,8 @@ class GemmaRMSNorm(CustomOp):
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
"""
# --8<-- [end:gemma_rms_norm]
def __init__(
self,
hidden_size: int,
@@ -321,6 +327,7 @@ class GemmaRMSNorm(CustomOp):
return self.forward_native(x, residual)
# --8<-- [start:rms_norm_gated]
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
"""RMS Normalization with optional gating.
@@ -331,6 +338,8 @@ class RMSNormGated(CustomOp):
- Optional gating with SiLU activation
"""
# --8<-- [end:rms_norm_gated]
def __init__(
self,
hidden_size: int,