[torch.compile] Enable AR+RMS fusion by default for -O2 when available (#34299)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
@@ -115,7 +115,7 @@ class PassConfig:
|
||||
"""Fuse the custom SiluMul + quant ops."""
|
||||
fuse_attn_quant: bool = Field(default=None)
|
||||
"""Fuse the custom attention + quant ops."""
|
||||
eliminate_noops: bool = Field(default=None)
|
||||
eliminate_noops: bool = Field(default=True)
|
||||
"""Eliminate no-op ops."""
|
||||
enable_sp: bool = Field(default=None)
|
||||
"""Enable sequence parallelism."""
|
||||
@@ -194,7 +194,6 @@ class PassConfig:
|
||||
"fuse_norm_quant",
|
||||
"fuse_act_quant",
|
||||
"fuse_attn_quant",
|
||||
"eliminate_noops",
|
||||
"enable_sp",
|
||||
"fuse_gemm_comms",
|
||||
"fuse_allreduce_rms",
|
||||
|
||||
@@ -102,6 +102,19 @@ def enable_act_fusion(cfg: "VllmConfig") -> bool:
|
||||
) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
|
||||
|
||||
|
||||
def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
    """Decide whether the allreduce + RMS fusion pass should be turned on.

    The pass is enabled only when all of the following hold, checked in
    this order (later checks are skipped once one fails):

    1. tensor parallel size is greater than 1,
    2. the current platform is CUDA,
    3. the device has capability 90 (i.e. Hopper or newer),
    4. flashinfer is available.
    """
    # Imports stay function-local, as in the original implementation.
    from vllm.platforms import current_platform
    from vllm.utils.flashinfer import has_flashinfer

    tp_size = cfg.parallel_config.tensor_parallel_size
    if not tp_size > 1:
        # Single-rank setups fail the first condition outright.
        return False

    # Short-circuits exactly like the original chained `and`.
    cuda_hopper_plus = (
        current_platform.is_cuda() and current_platform.has_device_capability(90)
    )
    return cuda_hopper_plus and has_flashinfer()
|
||||
|
||||
|
||||
def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
|
||||
"""Enable if using AITER RMSNorm and AITER Triton GEMMs
|
||||
and hidden size is 2880 i.e. gpt-oss; otherwise Inductor handles fusion."""
|
||||
@@ -118,7 +131,6 @@ def enable_norm_pad_fusion(cfg: "VllmConfig") -> bool:
|
||||
OPTIMIZATION_LEVEL_00 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"eliminate_noops": False,
|
||||
"fuse_norm_quant": False,
|
||||
"fuse_act_quant": False,
|
||||
"fuse_allreduce_rms": False,
|
||||
@@ -137,7 +149,6 @@ OPTIMIZATION_LEVEL_00 = {
|
||||
OPTIMIZATION_LEVEL_01 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"eliminate_noops": True,
|
||||
"fuse_norm_quant": enable_norm_fusion,
|
||||
"fuse_act_quant": enable_act_fusion,
|
||||
"fuse_allreduce_rms": False,
|
||||
@@ -156,10 +167,9 @@ OPTIMIZATION_LEVEL_01 = {
|
||||
OPTIMIZATION_LEVEL_02 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"eliminate_noops": True,
|
||||
"fuse_norm_quant": enable_norm_fusion,
|
||||
"fuse_act_quant": enable_act_fusion,
|
||||
"fuse_allreduce_rms": False,
|
||||
"fuse_allreduce_rms": enable_allreduce_rms_fusion,
|
||||
"fuse_attn_quant": IS_QUANTIZED,
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
@@ -175,10 +185,9 @@ OPTIMIZATION_LEVEL_02 = {
|
||||
OPTIMIZATION_LEVEL_03 = {
|
||||
"compilation_config": {
|
||||
"pass_config": {
|
||||
"eliminate_noops": True,
|
||||
"fuse_norm_quant": enable_norm_fusion,
|
||||
"fuse_act_quant": enable_act_fusion,
|
||||
"fuse_allreduce_rms": False,
|
||||
"fuse_allreduce_rms": enable_allreduce_rms_fusion,
|
||||
"fuse_attn_quant": IS_QUANTIZED,
|
||||
"enable_sp": IS_DENSE,
|
||||
"fuse_gemm_comms": IS_DENSE,
|
||||
|
||||
Reference in New Issue
Block a user