[torch.compile] Compile CustomOp.forward_native for SiluAndMul and QuantFP8 to avoid raw torch ops inside opaque custom ops (#32806)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
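The idea behind the change, as a minimal sketch: when compile_native is left at True, the base class can run forward_native through torch.compile, so the pure-PyTorch fallback is traced and fused rather than executed as raw eager torch ops inside an otherwise opaque custom op. The names CustomOp, forward_native, and compile_native come from the diff below; the wiring shown here (a plain torch.compile call in __init__) is an assumption, not vLLM's exact implementation.

import torch
from torch import nn


class CustomOpSketch(nn.Module):
    """Illustrative stand-in for vLLM's CustomOp base class (assumed wiring)."""

    def __init__(self, *, compile_native: bool = True):
        super().__init__()
        if compile_native:
            # Compile the pure-PyTorch fallback so Dynamo traces it instead of
            # leaving raw eager torch ops inside an opaque custom-op call.
            self.forward_native = torch.compile(self.forward_native)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real CustomOp dispatches between platform-specific implementations;
        # this sketch always takes the native path.
        return self.forward_native(x)


class SiluAndMulSketch(CustomOpSketch):
    """Toy version of SiluAndMul's native path: silu(gate) * up."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]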
@@ -75,8 +75,8 @@ class SiluAndMul(CustomOp):
     # --8<-- [end:silu_and_mul]
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, *, compile_native: bool = True):
+        super().__init__(compile_native=compile_native)
         if current_platform.is_cuda_alike():
             self.op = torch.ops._C.silu_and_mul
         elif current_platform.is_xpu():
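A hypothetical call site for the new keyword-only argument above; only the compile_native parameter and its default come from the hunk, the import path is assumed.

from vllm.model_executor.layers.activation import SiluAndMul  # assumed path

silu_mul = SiluAndMul()                             # forward_native compiled (default)
silu_mul_eager = SiluAndMul(compile_native=False)   # keep the eager native fallback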
@@ -38,6 +38,7 @@ class QuantFP8(CustomOp):
         column_major_scales: bool = False,
         tma_aligned_scales: bool = False,
         use_ue8m0: bool | None = None,  # for Torch compile
+        compile_native: bool = True,
     ):
         """
         :param static: static or dynamic quantization

@@ -49,8 +50,9 @@ class QuantFP8(CustomOp):
             TMA-aligned layout
         :param column_major_scales: For group quantization, output scales in
             column major format
+        :param compile_native: Manually compile forward_native if compile mode > None
         """
-        super().__init__()
+        super().__init__(compile_native=compile_native)
         self.static = static
         self.group_shape = group_shape
         self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
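Likewise for QuantFP8, a hypothetical construction showing the new flag. Only compile_native and its default are taken from the hunks above; the import paths and the static/group_shape arguments are assumptions based on the docstring and assignments shown.

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8  # assumed path
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape  # assumed path

quant = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)  # compiled forward_native
quant_eager = QuantFP8(
    static=False,
    group_shape=GroupShape.PER_TOKEN,
    compile_native=False,  # opt out of compiling the native path
)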
@@ -129,7 +129,7 @@ class ApplyRotaryEmb(CustomOp):
         is_neox_style: bool = True,
         enable_fp32_compute: bool = False,
     ) -> None:
-        super().__init__(enforce_enable)
+        super().__init__(enforce_enable=enforce_enable)
         self.is_neox_style = is_neox_style
         self.enable_fp32_compute = enable_fp32_compute
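The ApplyRotaryEmb hunk above only switches to keyword form. As a general illustration of why that is safer once a base __init__ grows extra parameters such as compile_native (the parameter order below is an assumption, not CustomOp's real signature):

class BaseSketch:
    def __init__(self, compile_native: bool = True, enforce_enable: bool = False):
        self.compile_native = compile_native
        self.enforce_enable = enforce_enable


class RotaryEmbSketch(BaseSketch):
    def __init__(self, enforce_enable: bool = False):
        # Positionally, super().__init__(enforce_enable) would bind to
        # compile_native; the keyword form cannot be misbound.
        super().__init__(enforce_enable=enforce_enable)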