[Doc] Add developer guide for CustomOp (#30886)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -88,6 +88,7 @@ def dispatch_rocm_rmsnorm_func(
|
||||
return rms_norm
|
||||
|
||||
|
||||
# --8<-- [start:rms_norm]
|
||||
@CustomOp.register("rms_norm")
|
||||
class RMSNorm(CustomOp):
|
||||
"""Root mean square normalization.
|
||||
@@ -96,6 +97,8 @@ class RMSNorm(CustomOp):
|
||||
Refer to https://arxiv.org/abs/1910.07467
|
||||
"""
|
||||
|
||||
# --8<-- [end:rms_norm]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -253,6 +256,7 @@ class RMSNorm(CustomOp):
|
||||
return s
|
||||
|
||||
|
||||
# --8<-- [start:gemma_rms_norm]
|
||||
@CustomOp.register("gemma_rms_norm")
|
||||
class GemmaRMSNorm(CustomOp):
|
||||
"""RMS normalization for Gemma.
|
||||
@@ -262,6 +266,8 @@ class GemmaRMSNorm(CustomOp):
|
||||
2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
|
||||
"""
|
||||
|
||||
# --8<-- [end:gemma_rms_norm]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -321,6 +327,7 @@ class GemmaRMSNorm(CustomOp):
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
# --8<-- [start:rms_norm_gated]
|
||||
@CustomOp.register("rms_norm_gated")
|
||||
class RMSNormGated(CustomOp):
|
||||
"""RMS Normalization with optional gating.
|
||||
@@ -331,6 +338,8 @@ class RMSNormGated(CustomOp):
|
||||
- Optional gating with SiLU activation
|
||||
"""
|
||||
|
||||
# --8<-- [end:rms_norm_gated]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
|
||||
Reference in New Issue
Block a user