[ROCm] AITER fused RoPE+KVCache (#33443)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com> Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Co-authored-by: charlifu <charlifu@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
This commit is contained in:
@@ -127,6 +127,13 @@ class PassConfig:

     # ROCm/AITER specific fusions
     fuse_act_padding: bool = Field(default=None)
     """Fuse the custom RMSNorm + padding ops."""
     fuse_rope_kvcache: bool = Field(default=None)
     """Fuse the QK rope + KV cache ops."""

     rope_kvcache_fusion_max_token_num: int = 256
     """The threshold for ROCm AITER RoPE+KVCache fusion, e.g. for small-batch decode.
     Larger batch sizes, e.g. during prefill, will use the unfused kernels.
     """

     fi_allreduce_fusion_max_size_mb: float | None = None
     """The threshold of the communicated tensor sizes under which
@@ -198,6 +205,7 @@ class PassConfig:
         "fuse_gemm_comms",
         "fuse_allreduce_rms",
         "fuse_act_padding",
         "fuse_rope_kvcache",
         mode="wrap",
     )
     @classmethod
@@ -243,6 +251,12 @@ class PassConfig:
                 "The fusion will be disabled."
             )
             self.fuse_act_padding = False
         if self.fuse_rope_kvcache and not current_platform.is_rocm():
             logger.warning_once(
                 "KV cache fusion currently only enabled on ROCm. "
                 "The fusion will be disabled."
             )
             self.fuse_rope_kvcache = False


 class DynamicShapesType(str, enum.Enum):
@@ -824,6 +838,19 @@ class CompilationConfig:
             # TODO(zhuhaoran): support rope native forward match and remove this.
             # Linked issue: https://github.com/vllm-project/vllm/issues/28042
             self.custom_ops.append("+rotary_embedding")
             if self.pass_config.fuse_rope_kvcache:
                 from vllm._aiter_ops import rocm_aiter_ops

                 if rocm_aiter_ops.is_triton_rotary_embed_enabled():
                     logger.warning(
                         "Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
                         "fuse_rope_kvcache. Disabling fuse_rope_kvcache."
                     )
                     self.pass_config.fuse_rope_kvcache = False
             else:
                 # TODO(Rohan138): support rope native forward match and remove this.
                 # Linked issue: https://github.com/vllm-project/vllm/issues/28042
                 self.custom_ops.append("+rotary_embedding")

         if (
             is_torch_equal_or_newer("2.9.0.dev")
@@ -1401,6 +1401,20 @@ class VllmConfig:
                     "allreduce-rms fusion will be enabled for all num_tokens."
                 )

         if compilation_config.pass_config.fuse_rope_kvcache:
             max_token_num = (
                 compilation_config.pass_config.rope_kvcache_fusion_max_token_num
             )
             if max_token_num is not None:
                 if compile_range_end is not None and max_token_num < compile_range_end:
                     computed_compile_ranges_split_points.append(max_token_num)
                 else:
                     logger.debug(
                         "Max num batched tokens below rope+kvcache fusion threshold, "
                         "rope+kvcache fusion enabled for num_tokens <= %d.",
                         compile_range_end,
                     )

         if compilation_config.compile_ranges_split_points is not None:
             for x in compilation_config.compile_ranges_split_points:
                 assert isinstance(x, int)
Reference in New Issue
Block a user