[ROCm]: Update rope+kvcache fusion conditions and disable custom op by default (#36716)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan Potdar
2026-03-25 15:58:44 -05:00
committed by GitHub
parent 70a2152830
commit a0e8c74005
5 changed files with 42 additions and 18 deletions

View File

@@ -22,7 +22,7 @@ or just on the low or high end.
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ | | ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low | | [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always | | [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O1 (ROCm/AITER only) | TBD | No | Low | | [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low | | [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low |
| [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High | | [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High |
| [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms` | GEMM → reduce-scatter / all-gather → GEMM | Off by default | 7-10% | Yes | High | | [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms` | GEMM → reduce-scatter / all-gather → GEMM | Off by default | 7-10% | Yes | High |

View File

@@ -56,7 +56,6 @@ Fusions:
- `-cc.pass_config.fuse_norm_quant=True`* - `-cc.pass_config.fuse_norm_quant=True`*
- `-cc.pass_config.fuse_act_quant=True`* - `-cc.pass_config.fuse_act_quant=True`*
- `-cc.pass_config.fuse_act_padding=True` - `-cc.pass_config.fuse_act_padding=True`
- `-cc.pass_config.fuse_rope_kvcache=True`† (will be moved to O2)
\* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.<br> \* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.<br>
† These fusions are ROCm-only and require AITER. † These fusions are ROCm-only and require AITER.
@@ -71,6 +70,9 @@ Settings (on top of `-O1`):
- `-cc.cudagraph_mode=FULL_AND_PIECEWISE` - `-cc.cudagraph_mode=FULL_AND_PIECEWISE`
- `-cc.pass_config.fuse_allreduce_rms=True` - `-cc.pass_config.fuse_allreduce_rms=True`
- `-cc.pass_config.fuse_rope_kvcache=True`†
† These fusions are ROCm-only and require AITER.
### `-O3`: Aggressive Optimization ### `-O3`: Aggressive Optimization

View File

@@ -1042,13 +1042,6 @@ class CompilationConfig:
self.splitting_ops = [] self.splitting_ops = []
return return
# NOTE: this function needs to be called only when mode is
# CompilationMode.VLLM_COMPILE
assert self.mode == CompilationMode.VLLM_COMPILE, (
"set_splitting_ops_for_v1 should only be called when "
"mode is CompilationMode.VLLM_COMPILE"
)
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition: if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
self.set_splitting_ops_for_attn_fusion() self.set_splitting_ops_for_attn_fusion()
else: else:
@@ -1069,6 +1062,16 @@ class CompilationConfig:
# that doesn't seem to affect performance. # that doesn't seem to affect performance.
# https://github.com/vllm-project/vllm/issues/33267 # https://github.com/vllm-project/vllm/issues/33267
if not self.use_inductor_graph_partition: if not self.use_inductor_graph_partition:
if self.pass_config.fuse_rope_kvcache:
logger.warning_once(
"fuse_rope_kvcache is enabled, but splitting_ops is None "
"and Inductor graph partition is not enabled."
"Disabling fuse_rope_kvcache."
"Please either set splitting_ops to an empty list []"
"or set use_inductor_graph_partition to True "
"to enable RoPE+KV cache fusion."
)
self.pass_config.fuse_rope_kvcache = False
self.splitting_ops.append("vllm::unified_kv_cache_update") self.splitting_ops.append("vllm::unified_kv_cache_update")
self.splitting_ops.append("vllm::unified_mla_kv_cache_update") self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
@@ -1142,6 +1145,29 @@ class CompilationConfig:
op in self.splitting_ops for op in self._attention_ops op in self.splitting_ops for op in self._attention_ops
) )
def splitting_ops_contain_kv_cache_update(self) -> bool:
    """Tell whether the KV-cache update ops are, or will become, splitting ops.

    Under Dynamo partitioning (``use_inductor_graph_partition`` is false),
    while ``splitting_ops`` is still ``None`` and attn+quant fusion is
    disabled, ``set_splitting_ops_for_v1`` appends the KV-cache update ops
    to ``splitting_ops`` because of
    https://github.com/vllm-project/vllm/issues/33267.
    In that situation the ops are not listed yet but will be added
    subsequently, so this method already answers ``True``.

    Returns:
        ``True`` if both KV-cache update ops are present in
        ``splitting_ops`` or are about to be appended as described above;
        ``False`` otherwise.
    """
    current_ops = self.splitting_ops

    if current_ops is None:
        # Nothing is listed yet: the answer depends only on whether the ops
        # will be appended later (Dynamo partition + no attn+quant fusion).
        return not (
            self.use_inductor_graph_partition
            or self.pass_config.fuse_attn_quant
        )

    # Both the regular and the MLA variant must be present.
    for op_name in (
        "vllm::unified_kv_cache_update",
        "vllm::unified_mla_kv_cache_update",
    ):
        if op_name not in current_ops:
            return False
    return True
def is_attention_compiled_piecewise(self) -> bool: def is_attention_compiled_piecewise(self) -> bool:
if not self.splitting_ops_contain_attention(): if not self.splitting_ops_contain_attention():
return False return False

View File

@@ -144,7 +144,10 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
return ( return (
rocm_aiter_ops.is_enabled() rocm_aiter_ops.is_enabled()
and cfg.compilation_config.is_custom_op_enabled("rotary_embedding") and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
and cfg.compilation_config.use_inductor_graph_partition and (
cfg.compilation_config.use_inductor_graph_partition
or not cfg.compilation_config.splitting_ops_contain_kv_cache_update()
)
) )
@@ -190,7 +193,7 @@ OPTIMIZATION_LEVEL_01 = {
"enable_sp": False, "enable_sp": False,
"fuse_gemm_comms": False, "fuse_gemm_comms": False,
"fuse_act_padding": enable_norm_pad_fusion, "fuse_act_padding": enable_norm_pad_fusion,
"fuse_rope_kvcache": enable_rope_kvcache_fusion, "fuse_rope_kvcache": False,
}, },
"cudagraph_mode": CUDAGraphMode.PIECEWISE, "cudagraph_mode": CUDAGraphMode.PIECEWISE,
"use_inductor_graph_partition": False, "use_inductor_graph_partition": False,

View File

@@ -649,13 +649,6 @@ class RocmPlatform(Platform):
and "-grouped_topk" not in compilation_config.custom_ops and "-grouped_topk" not in compilation_config.custom_ops
): ):
compilation_config.custom_ops.append("+grouped_topk") compilation_config.custom_ops.append("+grouped_topk")
# Enable rotary embedding customop when using AITER if not disabled by user
if (
rocm_aiter_ops.is_enabled()
and "+rotary_embedding" not in compilation_config.custom_ops
and "-rotary_embedding" not in compilation_config.custom_ops
):
compilation_config.custom_ops.append("+rotary_embedding")
# Default dispatch to rocm's sparse_attn_indexer implementation # Default dispatch to rocm's sparse_attn_indexer implementation
compilation_config.custom_ops.append("+sparse_attn_indexer") compilation_config.custom_ops.append("+sparse_attn_indexer")