From a0e8c74005f5782fad30be912c12cd37fc2813e9 Mon Sep 17 00:00:00 2001 From: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:58:44 -0500 Subject: [PATCH] [ROCm]: Update rope+kvcache fusion conditions and disable custom op by default (#36716) Signed-off-by: Rohan138 --- docs/design/fusions.md | 2 +- docs/design/optimization_levels.md | 4 ++- vllm/config/compilation.py | 40 ++++++++++++++++++++++++------ vllm/config/vllm.py | 7 ++++-- vllm/platforms/rocm.py | 7 ------ 5 files changed, 42 insertions(+), 18 deletions(-) diff --git a/docs/design/fusions.md b/docs/design/fusions.md index 26eb95c9d..28a29a7f3 100644 --- a/docs/design/fusions.md +++ b/docs/design/fusions.md @@ -22,7 +22,7 @@ or just on the low or high end. | ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ | | [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low | | [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always | -| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O1 (ROCm/AITER only) | TBD | No | Low | +| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low | | [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low | | [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq 
for AsyncTP | Yes | High | | [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms` | GEMM → reduce-scatter / all-gather → GEMM | Off by default | 7-10% | Yes | High | diff --git a/docs/design/optimization_levels.md b/docs/design/optimization_levels.md index 91af515f4..591978b54 100644 --- a/docs/design/optimization_levels.md +++ b/docs/design/optimization_levels.md @@ -56,7 +56,6 @@ Fusions: - `-cc.pass_config.fuse_norm_quant=True`* - `-cc.pass_config.fuse_act_quant=True`* - `-cc.pass_config.fuse_act_padding=True`† -- `-cc.pass_config.fuse_rope_kvcache=True`† (will be moved to O2) \* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.
† These fusions are ROCm-only and require AITER. @@ -71,6 +70,9 @@ Settings (on top of `-O1`): - `-cc.cudagraph_mode=FULL_AND_PIECEWISE` - `-cc.pass_config.fuse_allreduce_rms=True` +- `-cc.pass_config.fuse_rope_kvcache=True`† + +† These fusions are ROCm-only and require AITER. ### `-O3`: Aggressive Optimization diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index ec2f7e7f7..5b6648908 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1042,13 +1042,6 @@ class CompilationConfig: self.splitting_ops = [] return - # NOTE: this function needs to be called only when mode is - # CompilationMode.VLLM_COMPILE - assert self.mode == CompilationMode.VLLM_COMPILE, ( - "set_splitting_ops_for_v1 should only be called when " - "mode is CompilationMode.VLLM_COMPILE" - ) - if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition: self.set_splitting_ops_for_attn_fusion() else: @@ -1069,6 +1062,16 @@ class CompilationConfig: # that doesn't seem to affect performance. # https://github.com/vllm-project/vllm/issues/33267 if not self.use_inductor_graph_partition: + if self.pass_config.fuse_rope_kvcache: + logger.warning_once( + "fuse_rope_kvcache is enabled, but splitting_ops is None " + "and Inductor graph partition is not enabled. " + "Disabling fuse_rope_kvcache. " + "Please either set splitting_ops to an empty list [] " + "or set use_inductor_graph_partition to True " + "to enable RoPE+KV cache fusion." 
+ ) + self.pass_config.fuse_rope_kvcache = False self.splitting_ops.append("vllm::unified_kv_cache_update") self.splitting_ops.append("vllm::unified_mla_kv_cache_update") @@ -1142,6 +1145,29 @@ class CompilationConfig: op in self.splitting_ops for op in self._attention_ops ) + def splitting_ops_contain_kv_cache_update(self) -> bool: + # when using Dynamo partition while splitting ops is None + # and attn+quant fusion disabled, the kv_cache_update_ops are + # appended to splitting_ops in set_splitting_ops_for_v1 due to + # https://github.com/vllm-project/vllm/issues/33267 + # In this case, we return True if the kv_cache_update_ops + # are not in the splitting_ops yet, but will subsequently + # be added to splitting_ops. + if ( + not self.use_inductor_graph_partition + and self.splitting_ops is None + and not self.pass_config.fuse_attn_quant + ): + return True + + kv_cache_update_ops = [ + "vllm::unified_kv_cache_update", + "vllm::unified_mla_kv_cache_update", + ] + return self.splitting_ops is not None and all( + op in self.splitting_ops for op in kv_cache_update_ops + ) + def is_attention_compiled_piecewise(self) -> bool: if not self.splitting_ops_contain_attention(): return False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 65a78f4d0..ae92a19dd 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -144,7 +144,10 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool: return ( rocm_aiter_ops.is_enabled() and cfg.compilation_config.is_custom_op_enabled("rotary_embedding") - and cfg.compilation_config.use_inductor_graph_partition + and ( + cfg.compilation_config.use_inductor_graph_partition + or not cfg.compilation_config.splitting_ops_contain_kv_cache_update() + ) ) @@ -190,7 +193,7 @@ OPTIMIZATION_LEVEL_01 = { "enable_sp": False, "fuse_gemm_comms": False, "fuse_act_padding": enable_norm_pad_fusion, - "fuse_rope_kvcache": enable_rope_kvcache_fusion, + "fuse_rope_kvcache": False, }, "cudagraph_mode": CUDAGraphMode.PIECEWISE, 
"use_inductor_graph_partition": False, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f75af323a..2f76aedbf 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -649,13 +649,6 @@ class RocmPlatform(Platform): and "-grouped_topk" not in compilation_config.custom_ops ): compilation_config.custom_ops.append("+grouped_topk") - # Enable rotary embedding customop when using AITER if not disabled by user - if ( - rocm_aiter_ops.is_enabled() - and "+rotary_embedding" not in compilation_config.custom_ops - and "-rotary_embedding" not in compilation_config.custom_ops - ): - compilation_config.custom_ops.append("+rotary_embedding") # Default dispatch to rocm's sparse_attn_indexer implementation compilation_config.custom_ops.append("+sparse_attn_indexer")