[ROCm]: Update rope+kvcache fusion conditions and disable custom op by default (#36716)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
@@ -22,7 +22,7 @@ or just on the low or high end.
|
|||||||
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
|
| ------------------------------------------------------------------------------ | ---------------------------- | ---------------------------------------------- | ------------------------------ | ------------------ | --------- | ------------ |
|
||||||
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
|
| [AllReduce + RMSNorm](#allreduce--rmsnorm-fuse_allreduce_rms) | `fuse_allreduce_rms` | All-reduce → RMSNorm (+residual_add) (→ quant) | O2 (Hopper/Blackwell + TP > 1) | 5-20% | No | Low |
|
||||||
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
|
| [Attention + Quant](#attention--quantization-fuse_attn_quant) | `fuse_attn_quant` | Attention output → FP8/NVFP4 quant | Off by default | 3-7% | Yes | Always |
|
||||||
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O1 (ROCm/AITER only) | TBD | No | Low |
|
| [RoPE + KV-Cache Update](#rope--kv-cache-update-fuse_rope_kvcache) | `fuse_rope_kvcache` | Rotary embedding → KV cache write | O2 (ROCm/AITER only) | 2-4% | No | Low |
|
||||||
| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low |
|
| [QK Norm + RoPE](#qk-norm--rope-enable_qk_norm_rope_fusion) | `enable_qk_norm_rope_fusion` | Q/K RMSNorm → rotary embedding | Off by default | 2-3% | No | Low |
|
||||||
| [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High |
|
| [Sequence Parallelism](#sequence-parallelism-enable_sp) | `enable_sp` | AllReduce → ReduceScatter + AllGather | Off by default | Prereq for AsyncTP | Yes | High |
|
||||||
| [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms` | GEMM → reduce-scatter / all-gather → GEMM | Off by default | 7-10% | Yes | High |
|
| [AsyncTP GEMM + collective](#asynctp-gemm--collective-overlap-fuse_gemm_comms) | `fuse_gemm_comms` | GEMM → reduce-scatter / all-gather → GEMM | Off by default | 7-10% | Yes | High |
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ Fusions:
|
|||||||
- `-cc.pass_config.fuse_norm_quant=True`*
|
- `-cc.pass_config.fuse_norm_quant=True`*
|
||||||
- `-cc.pass_config.fuse_act_quant=True`*
|
- `-cc.pass_config.fuse_act_quant=True`*
|
||||||
- `-cc.pass_config.fuse_act_padding=True`†
|
- `-cc.pass_config.fuse_act_padding=True`†
|
||||||
- `-cc.pass_config.fuse_rope_kvcache=True`† (will be moved to O2)
|
|
||||||
|
|
||||||
\* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.</br>
|
\* These fusions are only enabled when either op is using a custom kernel, otherwise Inductor fusion is better.</br>
|
||||||
† These fusions are ROCm-only and require AITER.
|
† These fusions are ROCm-only and require AITER.
|
||||||
@@ -71,6 +70,9 @@ Settings (on top of `-O1`):
|
|||||||
|
|
||||||
- `-cc.cudagraph_mode=FULL_AND_PIECEWISE`
|
- `-cc.cudagraph_mode=FULL_AND_PIECEWISE`
|
||||||
- `-cc.pass_config.fuse_allreduce_rms=True`
|
- `-cc.pass_config.fuse_allreduce_rms=True`
|
||||||
|
- `-cc.pass_config.fuse_rope_kvcache=True`†
|
||||||
|
|
||||||
|
† These fusions are ROCm-only and require AITER.
|
||||||
|
|
||||||
### `-O3`: Aggressive Optimization
|
### `-O3`: Aggressive Optimization
|
||||||
|
|
||||||
|
|||||||
@@ -1042,13 +1042,6 @@ class CompilationConfig:
|
|||||||
self.splitting_ops = []
|
self.splitting_ops = []
|
||||||
return
|
return
|
||||||
|
|
||||||
# NOTE: this function needs to be called only when mode is
|
|
||||||
# CompilationMode.VLLM_COMPILE
|
|
||||||
assert self.mode == CompilationMode.VLLM_COMPILE, (
|
|
||||||
"set_splitting_ops_for_v1 should only be called when "
|
|
||||||
"mode is CompilationMode.VLLM_COMPILE"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
|
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
|
||||||
self.set_splitting_ops_for_attn_fusion()
|
self.set_splitting_ops_for_attn_fusion()
|
||||||
else:
|
else:
|
||||||
@@ -1069,6 +1062,16 @@ class CompilationConfig:
|
|||||||
# that doesn't seem to affect performance.
|
# that doesn't seem to affect performance.
|
||||||
# https://github.com/vllm-project/vllm/issues/33267
|
# https://github.com/vllm-project/vllm/issues/33267
|
||||||
if not self.use_inductor_graph_partition:
|
if not self.use_inductor_graph_partition:
|
||||||
|
if self.pass_config.fuse_rope_kvcache:
|
||||||
|
logger.warning_once(
|
||||||
|
"fuse_rope_kvcache is enabled, but splitting_ops is None "
|
||||||
|
"and Inductor graph partition is not enabled."
|
||||||
|
"Disabling fuse_rope_kvcache."
|
||||||
|
"Please either set splitting_ops to an empty list []"
|
||||||
|
"or set use_inductor_graph_partition to True "
|
||||||
|
"to enable RoPE+KV cache fusion."
|
||||||
|
)
|
||||||
|
self.pass_config.fuse_rope_kvcache = False
|
||||||
self.splitting_ops.append("vllm::unified_kv_cache_update")
|
self.splitting_ops.append("vllm::unified_kv_cache_update")
|
||||||
self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
|
self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
|
||||||
|
|
||||||
@@ -1142,6 +1145,29 @@ class CompilationConfig:
|
|||||||
op in self.splitting_ops for op in self._attention_ops
|
op in self.splitting_ops for op in self._attention_ops
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def splitting_ops_contain_kv_cache_update(self) -> bool:
|
||||||
|
# when using Dynamo partition while splitting ops is None
|
||||||
|
# and attn+quant fusion disabled, the kv_cache_update_ops are
|
||||||
|
# appended to splitting_ops in set_splitting_ops_for_v1 due to
|
||||||
|
# https://github.com/vllm-project/vllm/issues/33267
|
||||||
|
# In this case, we return True if the kv_cache_update_ops
|
||||||
|
# are not in the splitting_ops yet, but will subsequently
|
||||||
|
# be added to splitting_ops.
|
||||||
|
if (
|
||||||
|
not self.use_inductor_graph_partition
|
||||||
|
and self.splitting_ops is None
|
||||||
|
and not self.pass_config.fuse_attn_quant
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
kv_cache_update_ops = [
|
||||||
|
"vllm::unified_kv_cache_update",
|
||||||
|
"vllm::unified_mla_kv_cache_update",
|
||||||
|
]
|
||||||
|
return self.splitting_ops is not None and all(
|
||||||
|
op in self.splitting_ops for op in kv_cache_update_ops
|
||||||
|
)
|
||||||
|
|
||||||
def is_attention_compiled_piecewise(self) -> bool:
|
def is_attention_compiled_piecewise(self) -> bool:
|
||||||
if not self.splitting_ops_contain_attention():
|
if not self.splitting_ops_contain_attention():
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -144,7 +144,10 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool:
|
|||||||
return (
|
return (
|
||||||
rocm_aiter_ops.is_enabled()
|
rocm_aiter_ops.is_enabled()
|
||||||
and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
|
and cfg.compilation_config.is_custom_op_enabled("rotary_embedding")
|
||||||
and cfg.compilation_config.use_inductor_graph_partition
|
and (
|
||||||
|
cfg.compilation_config.use_inductor_graph_partition
|
||||||
|
or not cfg.compilation_config.splitting_ops_contain_kv_cache_update()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -190,7 +193,7 @@ OPTIMIZATION_LEVEL_01 = {
|
|||||||
"enable_sp": False,
|
"enable_sp": False,
|
||||||
"fuse_gemm_comms": False,
|
"fuse_gemm_comms": False,
|
||||||
"fuse_act_padding": enable_norm_pad_fusion,
|
"fuse_act_padding": enable_norm_pad_fusion,
|
||||||
"fuse_rope_kvcache": enable_rope_kvcache_fusion,
|
"fuse_rope_kvcache": False,
|
||||||
},
|
},
|
||||||
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
"cudagraph_mode": CUDAGraphMode.PIECEWISE,
|
||||||
"use_inductor_graph_partition": False,
|
"use_inductor_graph_partition": False,
|
||||||
|
|||||||
@@ -649,13 +649,6 @@ class RocmPlatform(Platform):
|
|||||||
and "-grouped_topk" not in compilation_config.custom_ops
|
and "-grouped_topk" not in compilation_config.custom_ops
|
||||||
):
|
):
|
||||||
compilation_config.custom_ops.append("+grouped_topk")
|
compilation_config.custom_ops.append("+grouped_topk")
|
||||||
# Enable rotary embedding customop when using AITER if not disabled by user
|
|
||||||
if (
|
|
||||||
rocm_aiter_ops.is_enabled()
|
|
||||||
and "+rotary_embedding" not in compilation_config.custom_ops
|
|
||||||
and "-rotary_embedding" not in compilation_config.custom_ops
|
|
||||||
):
|
|
||||||
compilation_config.custom_ops.append("+rotary_embedding")
|
|
||||||
|
|
||||||
# Default dispatch to rocm's sparse_attn_indexer implementation
|
# Default dispatch to rocm's sparse_attn_indexer implementation
|
||||||
compilation_config.custom_ops.append("+sparse_attn_indexer")
|
compilation_config.custom_ops.append("+sparse_attn_indexer")
|
||||||
|
|||||||
Reference in New Issue
Block a user