[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)

This commit is contained in:
Luka Govedič
2024-10-17 14:36:37 -04:00
committed by GitHub
parent 7871659abb
commit 0f41fbe5a3
8 changed files with 220 additions and 21 deletions

View File

@@ -7,6 +7,7 @@ import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
@@ -122,6 +123,7 @@ class RMSNorm(CustomOp):
return s
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.