[Rocm][torch.compile] Adding layernorm + fp8 block quant and silu + fp8 block quant for Aiter (#25693)
Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Micah Williamson <micah.williamson@amd.com> Signed-off-by: Charlie Fu <Charlie.Fu@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com> Co-authored-by: wuhuikx <hattie.wu@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import functools
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
|
||||
from .post_cleanup import PostCleanupPass
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from vllm.compilation.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion import RMSNormQuantFusionPass
|
||||
@@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
|
||||
|
||||
if self.pass_config.fuse_norm_quant:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
|
||||
if self.pass_config.fuse_act_quant:
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
Reference in New Issue
Block a user