[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)
This commit is contained in:
@@ -7,6 +7,7 @@ import torch.nn as nn
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
|
||||
@CustomOp.register("rms_norm")
|
||||
class RMSNorm(CustomOp):
|
||||
"""Root mean square normalization.
|
||||
|
||||
@@ -122,6 +123,7 @@ class RMSNorm(CustomOp):
|
||||
return s
|
||||
|
||||
|
||||
@CustomOp.register("gemma_rms_norm")
|
||||
class GemmaRMSNorm(CustomOp):
|
||||
"""RMS normalization for Gemma.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user