[ROCm]: Update rope+kvcache fusion conditions and disable custom op by default (#36716)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -22,7 +22,7 @@ or just on the low or high end.
|
||||
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
|
||||
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
|
||||
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
|
||||
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O1 (ROCm/AITER only) | TBD | No | Low |
|
||||
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
|
||||
| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low |
|
||||
| [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High |
|
||||
| [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms` | GEMM → reduce-scatter / all-gather → GEMM | Off by default | 7-10% | Yes | High |
|
||||
|
||||
@@ -56,7 +56,6 @@ Fusions:
|
||||
- `-cc.pass_config.fuse_norm_quant=True`*
|
||||
- `-cc.pass_config.fuse_act_quant=True`*
|
||||
- `-cc.pass_config.fuse_act_padding=True`†
|
||||
- `-cc.pass_config.fuse_rope_kvcache=True`† (will be moved to O2)
|
||||
|
||||
\* These fusions are only enabled when either op is using a custom kernel; otherwise, Inductor fusion is better.<br>
|
||||
† These fusions are ROCm-only and require AITER.
|
||||
@@ -71,6 +70,9 @@ Settings (on top of `-O1`):
|
||||
|
||||
- `-cc.cudagraph_mode=FULL_AND_PIECEWISE`
|
||||
- `-cc.pass_config.fuse_allreduce_rms=True`
|
||||
- `-cc.pass_config.fuse_rope_kvcache=True`†
|
||||
|
||||
† These fusions are ROCm-only and require AITER.
|
||||
|
||||
### `-O3`: Aggressive Optimization
|
||||
|
||||
|
||||
@@ -1042,13 +1042,6 @@ class CompilationConfig:
|
||||
self.splitting_ops = []
|
||||
return
|
||||
|
||||
# NOTE: this function needs to be called only when mode is
|
||||
# CompilationMode.VLLM_COMPILE
|
||||
assert self.mode == CompilationMode.VLLM_COMPILE, (
|
||||
"set_splitting_ops_for_v1 should only be called when "
|
||||
"mode is CompilationMode.VLLM_COMPILE"
|
||||
)
|
||||
|
||||
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
|
||||
self.set_splitting_ops_for_attn_fusion()
|
||||
else:
|
||||
@@ -1069,6 +1062,16 @@ class CompilationConfig:
|
||||
# that doesn't seem to affect performance.
|
||||
# https://github.com/vllm-project/vllm/issues/33267
|
||||
if not self.use_inductor_graph_partition:
|
||||
if self.pass_config.fuse_rope_kvcache:
|
||||
logger.warning_once(
|
||||
"fuse_rope_kvcache is enabled, but splitting_ops is None "
|
||||
"and Inductor graph partition is not enabled."
|
||||
"Disabling fuse_rope_kvcache."
|
||||
"Please either set splitting_ops to an empty list []"
|
||||
"or set use_inductor_graph_partition to True "
|
||||
"to enable RoPE+KV cache fusion."
|
||||
)
|
||||
self.pass_config.fuse_rope_kvcache = False
|
||||
self.splitting_ops.append("vllm::unified_kv_cache_update")
|
||||
self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
|
||||
|
||||
@@ -1142,6 +1145,29 @@ class CompilationConfig:
|
||||
op in self.splitting_ops for op in self._attention_ops
|
||||
)
|
||||
|
||||
def splitting_ops_contain_kv_cache_update(self) -> bool:
    """Tell whether both KV-cache-update ops are, or are about to be, splitting ops.

    Returns:
        True when ``vllm::unified_kv_cache_update`` and
        ``vllm::unified_mla_kv_cache_update`` are both present in
        ``self.splitting_ops``, or when they are not there yet but the
        configuration matches the case in which they get appended later.
        False otherwise, including when ``splitting_ops`` is ``None``
        outside of that case.
    """
    # With Dynamo partitioning (no Inductor graph partition), splitting_ops
    # still unset, and attn+quant fusion off, set_splitting_ops_for_v1
    # appends the KV-cache-update ops because of
    # https://github.com/vllm-project/vllm/issues/33267
    # The ops are absent right now but will be added, so answer True early.
    will_be_appended_later = not (
        self.use_inductor_graph_partition
        or self.splitting_ops is not None
        or self.pass_config.fuse_attn_quant
    )
    if will_be_appended_later:
        return True

    configured_ops = self.splitting_ops
    if configured_ops is None:
        return False

    required_ops = (
        "vllm::unified_kv_cache_update",
        "vllm::unified_mla_kv_cache_update",
    )
    # Every required op has to be listed; one missing op is enough to fail.
    for op_name in required_ops:
        if op_name not in configured_ops:
            return False
    return True
|
||||
|
||||
def is_attention_compiled_piecewise(self) -> bool:
|
||||
if not self.splitting_ops_contain_attention():
|
||||
return False
|
||||
|
||||
@@ -144,7 +144,10 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
|
||||
return (
|
||||
rocm_aiter_ops.is_enabled()
|
||||
and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
|
||||
and cfg.compilation_config.use_inductor_graph_partition
|
||||
and (
|
||||
cfg.compilation_config.use_inductor_graph_partition
|
||||
or not cfg.compilation_config.splitting_ops_contain_kv_cache_update()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -190,7 +193,7 @@ OPTIMIZATION_LEVEL_01 = {
|
||||
"enable_sp": False,
|
||||
"fuse_gemm_comms": False,
|
||||
"fuse_act_padding": enable_norm_pad_fusion,
|
||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
||||
"fuse_rope_kvcache": False,
|
||||
},
|
||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||
"use_inductor_graph_partition": False,
|
||||
|
||||
@@ -649,13 +649,6 @@ class RocmPlatform(Platform):
|
||||
and "-grouped_topk" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+grouped_topk")
|
||||
# Enable rotary embedding customop when using AITER if not disabled by user
|
||||
if (
|
||||
rocm_aiter_ops.is_enabled()
|
||||
and "+rotary_embedding" not in compilation_config.custom_ops
|
||||
and "-rotary_embedding" not in compilation_config.custom_ops
|
||||
):
|
||||
compilation_config.custom_ops.append("+rotary_embedding")
|
||||
|
||||
# Default dispatch to rocm's sparse_attn_indexer implementation
|
||||
compilation_config.custom_ops.append("+sparse_attn_indexer")
|
||||
|
||||
Reference in New Issue
Block a user