diff --git a/tests/compile/passes/test_functionalization.py b/tests/compile/passes/test_functionalization.py index e8da56b26..788ae7889 100644 --- a/tests/compile/passes/test_functionalization.py +++ b/tests/compile/passes/test_functionalization.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy + import pytest import torch -import vllm.envs as envs from tests.compile.backend import TestBackend from tests.utils import TestFP8Layer from vllm.compilation.passes.fusion.act_quant_fusion import ( @@ -31,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op TEST_FP8 = current_platform.supports_fp8() FP8_DTYPE = current_platform.fp8_dtype() @@ -198,23 +200,82 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module): return [torch.ops.aten.slice_scatter.default] -MODELS = [ - TestSiluMul, - TestFusedAddRMSNorm, - TestRotaryEmbedding, - TestRotaryEmbeddingSliceScatter, -] +class TestFunctionWithMutatedArgsAndReturn(torch.nn.Module): + OP_REGISTERED = False + + def __init__(self): + super().__init__() + self.register_test_custom_op() + + @classmethod + def register_test_custom_op(cls): + if not cls.OP_REGISTERED: + + def function_with_mutated_args_and_return_impl( + x: torch.Tensor, + ) -> torch.Tensor: + ret = x + 1 + x.add_(2) + return ret + + def function_with_mutated_args_and_return_fake( + x: torch.Tensor, + ) -> torch.Tensor: + return torch.empty_like(x) + + direct_register_custom_op( + op_name="function_with_mutated_args_and_return", + op_func=function_with_mutated_args_and_return_impl, + mutates_args=["x"], + fake_impl=function_with_mutated_args_and_return_fake, + ) + + cls.OP_REGISTERED = True + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # Clone 
of x is not taken here: the op mutates x in place and x is returned + ret = torch.ops.vllm.function_with_mutated_args_and_return(x) + return x, ret + + def example_inputs(self, num_tokens=32): + hidden_states = torch.randn(num_tokens) + return (hidden_states,) + + def ops_in_model(self, do_fusion): + return [torch.ops.vllm.function_with_mutated_args_and_return.default] + + def ops_not_in_model(self): + return [] + + +MODELS_AND_DO_FUSION = { + TestSiluMul: [True, False], + TestFusedAddRMSNorm: [True, False], + TestRotaryEmbedding: [False], + TestRotaryEmbeddingSliceScatter: [False], + TestFunctionWithMutatedArgsAndReturn: [False], +} @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("model_class", MODELS) -@pytest.mark.parametrize("do_fusion", [True, False]) -@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") +@pytest.mark.parametrize( + "model_class, do_fusion", + [ + (model_class, do_fusion) + for model_class, fusions in MODELS_AND_DO_FUSION.items() + for do_fusion in fusions + ], +) +@pytest.mark.skipif( + not current_platform.is_cuda_alike(), + reason="Only test on cuda and rocm platform", +) def test_fix_functionalization( model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype ): torch.set_default_device("cuda") torch.set_default_dtype(dtype) + torch.manual_seed(0) vllm_config = VllmConfig( model_config=ModelConfig(dtype=dtype), @@ -246,8 +307,14 @@ def test_fix_functionalization( backend_no_func = TestBackend(*passes) model = model_class() - torch.compile(model, backend=backend_func)(*model.example_inputs()) - torch.compile(model, backend=backend_no_func)(*model.example_inputs()) + inputs_func = model.example_inputs() + inputs_no_func = copy.deepcopy(inputs_func) + model_func = model_class() + model_no_func = copy.deepcopy(model_func) + model_func = torch.compile(model_func, backend=backend_func) + model_no_func = torch.compile(model_no_func, backend=backend_no_func) + model_func(*inputs_func) + 
model_no_func(*inputs_no_func) # check if the functionalization pass is applied for op in model.ops_in_model(do_fusion): @@ -265,3 +332,8 @@ def test_fix_functionalization( found[op] = True assert all(found[op] for op in model.ops_in_model(do_fusion)) assert all(not found.get(op) for op in model.ops_not_in_model()) + + # TODO (Rohan138): compare the outputs from model_func and model_no_func + # currently runs into errors while comparing `TestFusedAddRMSNorm` + # Linked issue: https://github.com/vllm-project/vllm/issues/34996 + # torch.testing.assert_close(outputs_func, outputs_no_func) diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py new file mode 100644 index 000000000..d074d2a9e --- /dev/null +++ b/tests/compile/passes/test_rope_kvcache_fusion.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.config +from tests.compile.backend import TestBackend +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata +from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops +from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP +from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass +from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass +from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass +from vllm.compilation.passes.utility.scatter_split_replace import ( + ScatterSplitReplacementPass, +) +from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass +from vllm.config import ( + CacheConfig, + CompilationConfig, + CompilationMode, + ModelConfig, + PassConfig, + VllmConfig, +) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.attention import Attention +from 
vllm.model_executor.layers.rotary_embedding import RotaryEmbedding +from vllm.platforms import current_platform +from vllm.v1.attention.backend import ( + AttentionBackend, + CommonAttentionMetadata, +) +from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.kv_cache_interface import AttentionSpec + +INDEX_SELECT_OP = torch.ops.aten.index.Tensor +VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update +FP8_DTYPE = current_platform.fp8_dtype() + + +class QKRoPEKVCacheTestModel(torch.nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + attn_backend: AttentionBackendEnum, + num_heads: int, + num_kv_heads: int, + head_size: int, + is_neox: bool, + dtype: torch.dtype, + device: torch.device, + prefix: str = "model.layers.0.self_attn.attn", + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_size = head_size + self.block_size = vllm_config.cache_config.block_size + self.q_size = num_heads * head_size + self.kv_size = num_kv_heads * head_size + self.is_neox = is_neox + self.dtype = dtype + self.device = device + self.layer_name = prefix + + self.rotary_emb = RotaryEmbedding( + head_size, + rotary_dim=head_size, + max_position_embeddings=4096, + base=10000, + is_neox_style=is_neox, + dtype=self.dtype, + ) + + # Whether to check for the RoPE custom op or component index_select + self.enable_rope_custom_op = self.rotary_emb.enabled() + + # Register layer metadata for the fusion pass via Attention. 
+ self.attn = Attention( + num_heads=num_heads, + head_size=head_size, + scale=1.0 / head_size**0.5, + num_kv_heads=num_kv_heads, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix=prefix, + attn_backend=attn_backend.get_class(), + ) + self.attn_backend: type[AttentionBackend] = self.attn.get_attn_backend() + assert not self.attn_backend.forward_includes_kv_cache_update, ( + f"Attention backend {self.attn_backend} does not support fuse_rope_kvcache." + ) + self.attn._k_scale = self.attn._k_scale.to(device) + self.attn._v_scale = self.attn._v_scale.to(device) + + kv_cache_dtype_str = vllm_config.cache_config.cache_dtype + self.kv_cache_dtype = ( + FP8_DTYPE if kv_cache_dtype_str.startswith("fp8") else self.dtype + ) + + # Initialize attn MetadataBuilder + self.builder = self.attn.attn_backend.get_builder_cls()( + kv_cache_spec=AttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=head_size, + dtype=self.kv_cache_dtype, + ), + layer_names=[self.attn.layer_name], + vllm_config=vllm_config, + device=device, + ) + + def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata: + """Initialize attention metadata.""" + # Create common attn metadata + batch_spec = BatchSpec(seq_lens=[1] * batch_size, query_lens=[1] * batch_size) + common_attn_metadata = create_common_attn_metadata( + batch_spec, self.block_size, self.device, arange_block_indices=True + ) + + max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size + num_blocks = batch_size * max_blocks + + # Fetch the attention backend and kv cache shape and stride order + attn_backend = self.attn.attn_backend + kv_cache_shape = attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size + ) + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + + 
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + inv_order = [ + kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + ] + + # Create dummy KV cache + raw_tensor = torch.zeros( + 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size, + dtype=self.kv_cache_dtype, + device=self.device, + ) + raw_tensor = raw_tensor.view(kv_cache_shape) + kv_cache = raw_tensor.permute(*inv_order) + + self.attn.kv_cache = [kv_cache] + + # Build attn metadata + attn_metadata = self.builder.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + + return attn_metadata + + def forward( + self, qkv: torch.Tensor, positions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Create copy so inplace ops do not modify the original tensors + qkv = qkv.clone() + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + # Instead of a full forward pass, match only the KV cache update op here + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size) + kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( + k, v, self.layer_name + ) + return q, k, v, kv_cache_dummy_dep + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + ops = [] + if self.enable_rope_custom_op: + ops.append(ROTARY_OP) + else: + ops.append(INDEX_SELECT_OP) + ops.append(torch.ops.vllm.unified_kv_cache_update.default) + return ops + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default] + + +@pytest.mark.parametrize( + "attn_backend", + [ + AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.ROCM_ATTN, + ], +) +@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False]) 
+@pytest.mark.parametrize("num_heads", [64]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("is_neox", [True, False]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) +@pytest.mark.skipif( + not is_aiter_found_and_supported(), + reason="Only test on ROCm with AITER installed and supported", +) +def test_rope_kvcache_fusion( + attn_backend: AttentionBackendEnum, + enable_rope_custom_op: bool, + num_heads: int, + num_kv_heads: int, + head_size: int, + block_size: int, + is_neox: bool, + dtype: torch.dtype, + kv_cache_dtype: str, + monkeypatch: pytest.MonkeyPatch, +): + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(0) + + custom_ops: list[str] = [] + if enable_rope_custom_op: + custom_ops.append("+rotary_embedding") + + vllm_config = VllmConfig( + model_config=ModelConfig(dtype=dtype), + cache_config=CacheConfig( + block_size=block_size, + cache_dtype=kv_cache_dtype, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + pass_config=PassConfig( + fuse_rope_kvcache=True, + eliminate_noops=True, + ), + ), + ) + + with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m: + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() + + model = QKRoPEKVCacheTestModel( + vllm_config=vllm_config, + attn_backend=attn_backend, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + is_neox=is_neox, + dtype=dtype, + device=torch.get_default_device(), + ) + + fusion_pass = RopeKVCacheFusionPass(vllm_config) + passes = [ + NoOpEliminationPass(vllm_config), + SplitCoalescingPass(vllm_config), + ScatterSplitReplacementPass(vllm_config), + fusion_pass, + PostCleanupPass(vllm_config), + ] + backend = TestBackend(*passes) + + T = 5 + + qkv = 
torch.randn( + T, num_heads * head_size + 2 * num_kv_heads * head_size, dtype=dtype + ) + pos = torch.arange(T, dtype=torch.long) + + qkv_unfused = qkv.clone() + pos_unfused = pos.clone() + + with set_forward_context(None, vllm_config): + forward_context = get_forward_context() + attn_metadata = model.build_attn_metadata(T) + forward_context.slot_mapping = { + model.layer_name: attn_metadata.slot_mapping + } + q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused) + attn_layer = forward_context.no_compile_layers[model.layer_name] + kv_cache_unfused = attn_layer.kv_cache[forward_context.virtual_engine] + del dummy + + torch._dynamo.mark_dynamic(qkv, 0) + torch._dynamo.mark_dynamic(pos, 0) + with set_forward_context(None, vllm_config): + model_fused = torch.compile(model, backend=backend) + forward_context = get_forward_context() + attn_metadata = model_fused.build_attn_metadata(T) + forward_context.slot_mapping = { + model.layer_name: attn_metadata.slot_mapping + } + q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos) + attn_layer = forward_context.no_compile_layers[model.layer_name] + kv_cache_fused = attn_layer.kv_cache[forward_context.virtual_engine] + del dummy + + assert fusion_pass.matched_count == 1 + + backend.check_before_ops(model.ops_in_model_before()) + backend.check_after_ops(model.ops_in_model_after()) + + if dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close(q_unfused, q_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(k_unfused, k_fused, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(v_unfused, v_fused, atol=ATOL, rtol=RTOL) + # Cannot compare fp8_* directly here, cast to model dtype instead + torch.testing.assert_close( + kv_cache_unfused.view(dtype), + kv_cache_fused.view(dtype), + atol=ATOL, + rtol=RTOL, + ) diff --git a/tests/compile/passes/test_scatter_split_replace.py b/tests/compile/passes/test_scatter_split_replace.py new file mode 100644 
index 000000000..659960896 --- /dev/null +++ b/tests/compile/passes/test_scatter_split_replace.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn as nn + +import vllm +from tests.compile.backend import TestBackend +from vllm.compilation.passes.utility.scatter_split_replace import ( + ScatterSplitReplacementPass, +) +from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass +from vllm.config import CompilationConfig, CompilationMode, VllmConfig +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + + +class ScatterSplitReplacementModel(nn.Module): + """Model with a rope+getitem+slice_scatter+split_with_sizes sequence.""" + + def __init__( + self, + num_heads: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + ): + super().__init__() + self.q_size = num_heads * head_size + self.kv_size = num_kv_heads * head_size + + self.rotary_emb = RotaryEmbedding( + head_size, + rotary_dim=head_size, + max_position_embeddings=4096, + base=10000, + is_neox_style=True, + dtype=dtype, + ) + + def forward(self, qkv: torch.Tensor, positions: torch.Tensor): + # Create copy so inplace ops do not modify the original tensors + qkv = qkv.clone() + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + q = q + 1 + k = k + 2 + v = v + 3 + return q, k, v + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + return [ + torch.ops.aten.slice_scatter.default, + torch.ops.aten.split_with_sizes.default, + torch.ops.aten.getitem.default, + ] + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + return [torch.ops.aten.getitem.default] + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_scatter_split_replace(dtype): + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + 
torch.manual_seed(0) + + num_heads = 8 + num_kv_heads = 4 + head_size = 64 + + vllm_config = VllmConfig( + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=["+rotary_embedding"], + ), + ) + with vllm.config.set_current_vllm_config(vllm_config): + # ScatterSplitReplacementPass requires SplitCoalescingPass to be run before it + coalesce_pass = SplitCoalescingPass(vllm_config) + replace_pass = ScatterSplitReplacementPass(vllm_config) + passes = [coalesce_pass, replace_pass] + backend = TestBackend(*passes) + + model = ScatterSplitReplacementModel(num_heads, num_kv_heads, head_size, dtype) + + T = 5 + qkv = torch.randn( + T, num_heads * head_size + 2 * num_kv_heads * head_size, dtype=dtype + ) + pos = torch.arange(T, dtype=torch.long) + + qkv_eager = qkv.clone() + pos_eager = pos.clone() + result_eager = model(qkv_eager, pos_eager) + + torch._dynamo.mark_dynamic(qkv, 0) + torch._dynamo.mark_dynamic(pos, 0) + + model_compiled = torch.compile(model, backend=backend) + result_compiled = model_compiled(qkv, pos) + + for eager, compiled in zip(result_eager, result_compiled): + torch.testing.assert_close(eager, compiled) + + assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0 + assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1 diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index b6d918b41..8c3a62b6e 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -179,7 +179,7 @@ def create_and_prepopulate_kv_cache( block_table[i, :num_blocks_for_seq] = inv_perm[start:end] start_block_idx += num_blocks_for_seq - # Create a realistic slot mapping that corresponds to the block table + # Create a realistic slot mapping that corresponds to the block table for i in range(batch_size): token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) block_indices = token_offsets // block_size 
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c544d2d3d..012a3f367 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1518,6 +1518,45 @@ class rocm_aiter_ops: query = query.view(query_shape) key = key.view(key_shape) + @staticmethod + def triton_rope_and_cache( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + flash_layout: bool, + apply_scale: bool, + ): + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache + + cos, sin = cos_sin_cache.chunk(2, dim=-1) + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + layer_slot_mapping, + positions, + cos, + sin, + k_scale, + v_scale, + is_neox, + flash_layout=flash_layout, + apply_scale=apply_scale, + q_out=query, + k_out=key, + output_zeros=False, + ) + @staticmethod def batched_gemm_a16wfp4( X: torch.Tensor, diff --git a/vllm/compilation/passes/fusion/rope_kvcache_fusion.py b/vllm/compilation/passes/fusion/rope_kvcache_fusion.py new file mode 100644 index 000000000..830a96407 --- /dev/null +++ b/vllm/compilation/passes/fusion/rope_kvcache_fusion.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +import torch._inductor.pattern_matcher as pm +from torch import fx +from torch._higher_order_ops import auto_functionalized +from torch._inductor.fx_passes.post_grad import view_to_reshape +from torch._inductor.pattern_matcher import PatternMatcherPass + +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.utils import Range +from vllm.logger import init_logger +from vllm.model_executor.layers.attention.attention import ( + Attention, + get_attention_context, +) +from vllm.utils.torch_utils 
import direct_register_custom_op + +from ..inductor_pass import enable_fake_mode +from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass +from .matcher_utils import ( + MatcherRotaryEmbedding, +) +from .rms_quant_fusion import ( + empty_bf16, + empty_i64, +) + +logger = init_logger(__name__) + + +def fused_rope_and_unified_kv_cache_update_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + layer_name: str = "", +) -> torch.Tensor: + """ + This impl fetches the KV cache and slot mapping from the forward context, + then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method. + It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`, + that is passed to unified_attention to signal a side effect and + the data dependency between them to ensure torch.compile preserves ordering. + """ + _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name) + if layer_slot_mapping is not None: + attn_layer.impl.do_rope_and_kv_cache_update( + attn_layer, + query, + key, + value, + positions, + cos_sin_cache, + is_neox, + kv_cache, + layer_slot_mapping, + ) + + return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) + + +def fused_rope_and_unified_kv_cache_update_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + layer_name: str = "", +) -> torch.Tensor: + return torch.empty(0, device=query.device, dtype=query.dtype) + + +direct_register_custom_op( + op_name="fused_rope_and_unified_kv_cache_update", + op_func=fused_rope_and_unified_kv_cache_update_impl, + mutates_args=["query", "key"], + fake_impl=fused_rope_and_unified_kv_cache_update_fake, +) + + +class RopeReshapeKVCachePattern: + """ + This pattern matches the following unfused inplace ops: + q, k = rotary_embedding(positions, q, k, head_size, 
cos_sin_cache, is_neox) + kv_cache_dummy = unified_kv_cache_update(k, v, layer_name) + + and replaces it with the fused inplace op: + kv_cache_dummy = fused_rope_and_unified_kv_cache_update( + q, k, v, positions, cos_sin_cache, is_neox, layer_name + ) + """ + + FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default + + def __init__( + self, + layer: Attention, + is_neox: bool, + ) -> None: + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.num_kv_heads = layer.num_kv_heads + self.head_size = layer.head_size + self.head_size_v = layer.head_size_v + self.is_neox = is_neox + + self.q_size = self.num_heads * self.head_size + self.k_size = self.num_kv_heads * self.head_size + self.v_size = self.num_kv_heads * self.head_size_v + + self.rope_matcher = MatcherRotaryEmbedding( + is_neox=self.is_neox, + head_size=self.head_size, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + ) + + def get_inputs(self) -> list[torch.Tensor]: + # Sample inputs to help pattern tracing + T = 5 + L = 4096 + qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size) + positions = empty_i64(T) + cos_sin_cache = empty_bf16(L, self.head_size) + return [ + qkv, + positions, + cos_sin_cache, + ] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + qkv: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q, k = self.rope_matcher(positions, q, k, cos_sin_cache) + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size_v) + dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name) + return dummy, q, k, v + + def replacement( + qkv: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, 
torch.Tensor, torch.Tensor]: + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size_v) + results = auto_functionalized( + self.FUSED_OP, + query=q, + key=k, + value=v, + positions=positions, + cos_sin_cache=cos_sin_cache, + is_neox=self.is_neox, + layer_name=self.layer_name, + ) + return results[0], results[1], results[2], v + + # NOTE: use view_to_reshape to unify view/reshape to simplify + # pattern and increase matching opportunities + def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule: + gm = pm.fwd_only(*args, **kwargs) + view_to_reshape(gm) + return gm + + pm.register_replacement( + pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass + ) + + +class RopeKVCacheFusionPass(VllmPatternMatcherPass): + """ + This pass fuses the rotary embedding and KV cache update operations + into a single fused kernel if available. + + It uses the pattern matcher and matches each layer manually, as strings + cannot be wildcarded. This also lets us check support on attention layers + upon registration instead of during pattern matching. + + This fusion eliminates the need for separate kernel launches and + intermediate memory operations between the RoPE and cache update steps. 
+ """ + + @enable_fake_mode + def __init__(self, config: VllmConfig) -> None: + super().__init__(config) + + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="rope_kv_cache_fusion_pass" + ) + + cc = config.compilation_config + self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num + + attn_layers = get_layers_from_vllm_config(config, Attention) + for _, layer in attn_layers.items(): + if layer.impl.fused_rope_kvcache_supported(): + for is_neox in [True, False]: + RopeReshapeKVCachePattern( + layer=layer, + is_neox=is_neox, + ).register(self.patterns) + + self.dump_patterns(config, self.patterns) + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + self.matched_count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", self.matched_count) + + def is_applicable_for_range(self, compile_range: Range) -> bool: + # This pass works best for the small-batch decode setting. + # For large-batch e.g. prefill, it is better to use two separate kernels + # since they are compute bound and the fused kernels require further tuning. 
+ return compile_range.end <= self.max_token_num + + def uuid(self) -> str: + return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern) diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index d9d3cc30b..70f86c8d2 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -28,7 +28,9 @@ if current_platform.is_cuda_alike(): from .fusion.attn_quant_fusion import AttnFusionPass from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass from .fusion.rms_quant_fusion import RMSNormQuantFusionPass + from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass from .fusion.sequence_parallelism import SequenceParallelismPass + from .utility.scatter_split_replace import ScatterSplitReplacementPass from .utility.split_coalescing import SplitCoalescingPass if current_platform.is_cuda(): @@ -136,6 +138,11 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc] if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled(): self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)] + if self.pass_config.fuse_rope_kvcache: + self.passes += [SplitCoalescingPass(config)] + self.passes += [ScatterSplitReplacementPass(config)] + self.passes += [RopeKVCacheFusionPass(config)] + if self.pass_config.fuse_attn_quant: self.passes += [AttnFusionPass(config)] diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index 55126a757..c7df5f92e 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -162,6 +162,24 @@ class FixFunctionalizationPass(VllmInductorPass): "position_ids", ) self.defunctionalize(graph, node, mutated_args=mutated_args, args=args) + elif ( + hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update") + and at_target + == torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default + ): + 
mutated_args = { + 1: "query", + 2: "key", + } + self.defunctionalize(graph, node, mutated_args=mutated_args) + # only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn + elif ( + hasattr(torch.ops.vllm, "function_with_mutated_args_and_return") + and at_target + == torch.ops.vllm.function_with_mutated_args_and_return.default + ): + mutated_args = {1: "x"} + self.defunctionalize(graph, node, mutated_args=mutated_args) else: continue # skip the count @@ -208,13 +226,20 @@ class FixFunctionalizationPass(VllmInductorPass): self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str] ) -> None: """ - Replace all getitem users of the auto-functionalized node with the + Replace mutated getitem users of the auto-functionalized node with the mutated arguments. :param node: The auto-functionalized node :param mutated_args: The mutated arguments, indexed by getitem index. If the value of an arg is a string, `node.kwargs[arg]` is used. """ for idx, user in self.getitem_users(node).items(): + # Some functionalized nodes may return both a result at getitem[0] + # as well as mutated args at getitem[1:...] 
+ if idx == 0: + assert idx not in mutated_args, ( + f"result at getitem[0] should not be in mutated_args for {node}" + ) + continue arg = mutated_args[idx] arg = node.kwargs[arg] if isinstance(arg, str) else arg user.replace_all_uses_with(arg) @@ -257,10 +282,20 @@ class FixFunctionalizationPass(VllmInductorPass): with graph.inserting_before(node): function = node.args[0] if args is None: - graph.call_function(function, kwargs=node.kwargs) + fn_node = graph.call_function(function, kwargs=node.kwargs) else: # Args passed as strings refer to items in node.kwargs args = tuple( node.kwargs[arg] if isinstance(arg, str) else arg for arg in args ) - graph.call_function(function, args=args) + fn_node = graph.call_function(function, args=args) + + # If the function returns a value as well as mutating args inplace, + # the functionalized node will have a getitem[0] user that holds this value + # Replace getitem[0] user of the auto-functionalized node + # with the new defunctionalized node directly if it exists + users = self.getitem_users(node) + if 0 in users: + user = users[0] + user.replace_all_uses_with(fn_node) + self._remove(user) diff --git a/vllm/compilation/passes/utility/scatter_split_replace.py b/vllm/compilation/passes/utility/scatter_split_replace.py new file mode 100644 index 000000000..1826c07f8 --- /dev/null +++ b/vllm/compilation/passes/utility/scatter_split_replace.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single +assignment if there are no users for the inplace tensor written to by +the slice_scatter call. + +The inplace rotary_embedding custom op takes in mutable query and key inputs +that are split+getitem outputs of a single qkv tensor. +When functionalized, we fetch the rotated query and key from the functionalized op +using `getitem` calls. 
However, we also write to the qkv tensor inplace using a +`slice_scatter`, then split the inplace tensor to get the output tensors again. +Instead, if the inplace tensor has no subsequent users, we can just replace the +`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls. + +This is already done in fix_functionalization::FixFunctionalizationPass, but +writing a custom pass for it before defunctionalization allows matching against the +qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass. +""" + +import operator + +import torch +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +from ..fx_utils import is_func +from ..vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + + +class ScatterSplitReplacementPass(VllmInductorPass): + """Replace getitem+slice_scatter+split nodes with a single getitem when + the inplace subtensor written to by the slice_scatter has no other users. 
+ + Here's an example graph with q_size = 512, kv_size = 64: + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1) + at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k)) + q = operator.getitem(at, 1) + k = operator.getitem(at, 2) + torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1) + torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1) + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1) + q = operator.getitem(split_with_sizes_2, 0) + k = operator.getitem(split_with_sizes_2, 1) + v = operator.getitem(split_with_sizes_2, 2) + + After this pass, this sequence of nodes is replaced with: + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1) + at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k)) + q = operator.getitem(at, 1) + k = operator.getitem(at, 2) + v = operator.getitem(split_with_sizes_1, 2) + """ + + @VllmInductorPass.time_and_log + def __call__(self, graph: fx.Graph) -> None: + count = 0 + + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs["query"] + key = kwargs["key"] + getitem_nodes = {} + for user in node.users: + if is_func(user, operator.getitem): + getitem_nodes[user.args[1]] = user + + if ( + is_func(query, operator.getitem) + and is_func(key, operator.getitem) + and query.args[0] == key.args[0] + and is_func(query.args[0], torch.ops.aten.split_with_sizes.default) + and all( + is_func(user, torch.ops.aten.slice_scatter.default) + for getitem_node in getitem_nodes.values() + for user in getitem_node.users + ) + ): + # Pattern where query and key are slices of a qkv tensor. + # While functionalized, results at [1] and [2] are scattered + # back into qkv, then split again to get query and key. 
+ # If the inplace tensor has no other users, we can replace + # the slice_scatter+split nodes with the original results. + for user in getitem_nodes[1].users: + slice_scatter_1_node = user + if not is_func( + slice_scatter_1_node, torch.ops.aten.slice_scatter.default + ): + continue + + for user in getitem_nodes[2].users: + slice_scatter_2_node = user + if not is_func( + slice_scatter_2_node, torch.ops.aten.slice_scatter.default + ): + continue + + for user in slice_scatter_2_node.users: + split_node = user + if not is_func(split_node, torch.ops.aten.split_with_sizes.default): + continue + + split_getitem_users = {} + for user in split_node.users: + if is_func(user, operator.getitem): + split_getitem_users[user.args[1]] = user + + # Replace query node + split_getitem_users[0].replace_all_uses_with(getitem_nodes[1]) + graph.erase_node(split_getitem_users[0]) + # Replace key node + split_getitem_users[1].replace_all_uses_with(getitem_nodes[2]) + graph.erase_node(split_getitem_users[1]) + # Redirect value node to original qkv tensor + split_getitem_users[2].replace_input_with(split_node, query.args[0]) + + # Erase unused nodes + graph.erase_node(split_node) + graph.erase_node(slice_scatter_2_node) + graph.erase_node(slice_scatter_1_node) + + count += 1 + + logger.debug("Eliminated %d slice_scatter+split nodes", count) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index f1909ace6..b1f0779c7 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -127,6 +127,13 @@ class PassConfig: # ROCm/AITER specific fusions fuse_act_padding: bool = Field(default=None) """Fuse the custom RMSNorm + padding ops.""" + fuse_rope_kvcache: bool = Field(default=None) + """Fuse the QK rope + KV cache ops.""" + + rope_kvcache_fusion_max_token_num: int = 256 + """The threshold for ROCm AITER RoPE+KVCache fusion e.g. for small batch decode. + Larger batch sizes e.g. during prefill will use the unfused kernels. 
+ """ fi_allreduce_fusion_max_size_mb: float | None = None """The threshold of the communicated tensor sizes under which @@ -198,6 +205,7 @@ class PassConfig: "fuse_gemm_comms", "fuse_allreduce_rms", "fuse_act_padding", + "fuse_rope_kvcache", mode="wrap", ) @classmethod @@ -243,6 +251,12 @@ class PassConfig: "The fusion will be disabled." ) self.fuse_act_padding = False + if self.fuse_rope_kvcache and not current_platform.is_rocm(): + logger.warning_once( + "KV cache fusion currently only enabled on ROCm. " + "The fusion will be disabled." + ) + self.fuse_rope_kvcache = False class DynamicShapesType(str, enum.Enum): @@ -824,6 +838,19 @@ class CompilationConfig: # TODO(zhuhaoran): support rope native forward match and remove this. # Linked issue: https://github.com/vllm-project/vllm/issues/28042 self.custom_ops.append("+rotary_embedding") + if self.pass_config.fuse_rope_kvcache: + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_triton_rotary_embed_enabled(): + logger.warning( + "Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with " + "fuse_rope_kvcache. Disabling fuse_rope_kvcache." + ) + self.pass_config.fuse_rope_kvcache = False + else: + # TODO(Rohan138): support rope native forward match and remove this. + # Linked issue: https://github.com/vllm-project/vllm/issues/28042 + self.custom_ops.append("+rotary_embedding") if ( is_torch_equal_or_newer("2.9.0.dev") diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 5db217b22..a9930c490 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1401,6 +1401,20 @@ class VllmConfig: "allreduce-rms fusion will be enabled for all num_tokens." 
) + if compilation_config.pass_config.fuse_rope_kvcache: + max_token_num = ( + compilation_config.pass_config.rope_kvcache_fusion_max_token_num + ) + if max_token_num is not None: + if compile_range_end is not None and max_token_num < compile_range_end: + computed_compile_ranges_split_points.append(max_token_num) + else: + logger.debug( + "Max num batched tokens below rope+kvcache fusion threshold, " + "rope+kvcache fusion enabled for num_tokens <= %d.", + compile_range_end, + ) + if compilation_config.compile_ranges_split_points is not None: for x in compilation_config.compile_ranges_split_points: assert isinstance(x, int) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 8c3ff3cc4..ea627a93d 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -570,11 +570,11 @@ direct_register_custom_op( def get_attention_context( layer_name: str, -) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]: +) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]: """Extract attention context for a given layer. This helper function extracts the attention metadata, attention layer - instance, and KV cache tensor for a specific layer. + instance, KV cache tensor, and slot mapping for a specific layer. Args: layer_name: The name/identifier of the attention layer. @@ -585,6 +585,7 @@ def get_attention_context( no metadata available - attn_layer: The attention layer instance (Attention or MLAAttention) - kv_cache: The KV cache tensor for current virtual engine + - slot_mapping: The slot mapping for this specific layer Note: attn_metadata may be None, but attn_layer and kv_cache are always extracted from the forward context. 
@@ -593,9 +594,14 @@ def get_attention_context( attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[layer_name] - attn_layer = forward_context.no_compile_layers[layer_name] + attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name] kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] - return attn_metadata, attn_layer, kv_cache + slot_mapping = forward_context.slot_mapping + assert isinstance(slot_mapping, dict), ( + f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " + ) + layer_slot_mapping = slot_mapping.get(layer_name) + return attn_metadata, attn_layer, kv_cache, layer_slot_mapping @maybe_transfer_kv_layer @@ -605,7 +611,7 @@ def unified_attention( value: torch.Tensor, layer_name: str, ) -> torch.Tensor: - attn_metadata, self, kv_cache = get_attention_context(layer_name) + attn_metadata, self, kv_cache, _ = get_attention_context(layer_name) output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) return output @@ -636,15 +642,7 @@ def unified_kv_cache_update( Returns a dummy that is passed to unified_attention to signal a side effect and the data dependency between them to ensure torch.compile preserves ordering. """ - forward_context = get_forward_context() - attn_layer = forward_context.no_compile_layers[layer_name] - kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] - - slot_mapping = forward_context.slot_mapping - assert isinstance(slot_mapping, dict), ( - f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. 
" - ) - layer_slot_mapping = slot_mapping.get(layer_name) + _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name) if layer_slot_mapping is not None: assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( f"{attn_layer.impl.__class__.__name__} does not support kv cache update" @@ -691,7 +689,7 @@ def unified_attention_with_output( # that ensures torch.compile preserves ordering between KV cache update and # attention forward. del kv_cache_dummy_dep - attn_metadata, self, kv_cache = get_attention_context(layer_name) + attn_metadata, self, kv_cache, _ = get_attention_context(layer_name) self.impl.forward( self, diff --git a/vllm/model_executor/layers/attention/kv_transfer_utils.py b/vllm/model_executor/layers/attention/kv_transfer_utils.py index 9ee6b4d0f..4afc5ccb1 100644 --- a/vllm/model_executor/layers/attention/kv_transfer_utils.py +++ b/vllm/model_executor/layers/attention/kv_transfer_utils.py @@ -40,8 +40,8 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable: layer_name: str = args[layer_name_index] - # Extract attention context (layer-specific metadata, layer, and kv_cache) - attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name) + # Extract attention context (metadata, layer, kv_cache, layer_slot_mapping) + attn_metadata, _, kv_cache, _ = get_attention_context(layer_name) connector = get_kv_transfer_group() if attn_metadata is None or not connector.has_connector_metadata(): return func(*args, **kwargs) diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index faebad596..d444e20da 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -828,7 +828,7 @@ def unified_mla_attention( k_pe: torch.Tensor, layer_name: str, ) -> torch.Tensor: - attn_metadata, layer, kv_cache = get_attention_context(layer_name) + attn_metadata, layer, kv_cache, _ = 
get_attention_context(layer_name) output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) return output @@ -862,7 +862,7 @@ def unified_mla_attention_with_output( output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: - attn_metadata, layer, kv_cache = get_attention_context(layer_name) + attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) layer.forward_impl( q, kv_c_normed, diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 9c004d772..864beda10 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -723,6 +723,33 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]): """ return False + def fused_rope_kvcache_supported(self): + """ + Does this attention implementation support RoPE+KVCache fusion. + This is used by the RopeKVCacheFusionPass to only fuse the RoPE ops + with the KV cache update for implementations that support it. + """ + return False + + def do_rope_and_kv_cache_update( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + ): + """ + If `fused_rope_kvcache_supported` returns True, this method will be called + by torch.ops.vllm.fused_rope_and_unified_kv_cache_update + to perform the inplace RoPE and KV cache update. 
+ """ + raise NotImplementedError + class MLAAttentionImpl(AttentionImplBase[T], Generic[T]): """MLA attention implementation with forward_mqa and forward_mha methods.""" diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 0c1e1b5e0..b9ca39d8e 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -11,7 +11,6 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.attention.attention import get_attention_context from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import get_cu_count @@ -1290,11 +1289,6 @@ class AiterFlashAttentionImpl(AttentionImpl): kv_cache: torch.Tensor, slot_mapping: torch.Tensor, ): - attn_metadata, _, _ = get_attention_context(layer.layer_name) - if attn_metadata is None: - # Profiling run. - return - key_cache, value_cache = kv_cache.unbind(0) # key and value may be None in the case of cross attention. They are @@ -1303,45 +1297,40 @@ class AiterFlashAttentionImpl(AttentionImpl): if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype()) - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping - # is not padded. However, we don't need to do - # key[:num_actual_tokens] and value[:num_actual_tokens] because - # the reshape_and_cache_flash op uses the slot_mapping's shape - # to determine the number of actual tokens. 
- if rocm_aiter_ops.is_shuffle_kv_cache_enabled(): - # We may calculate per token quant scale in - # reshape_and_cache_shuffle_triton which might differ from - # vllm's style when shuffle layout is used. - k_scale = attn_metadata.k_scale - v_scale = attn_metadata.v_scale - assert k_scale is not None and v_scale is not None, ( - "k_scale and v_scale are required for shuffled update" - ) - reshape_and_cache_shuffle_triton( - key, - value, - key_cache, - value_cache, - slot_mapping, - self.kv_cache_dtype, - k_scale, - v_scale, - ) - else: - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping + # is not padded. However, we don't need to do + # key[:num_actual_tokens] and value[:num_actual_tokens] because + # the reshape_and_cache_flash op uses the slot_mapping's shape + # to determine the number of actual tokens. + if rocm_aiter_ops.is_shuffle_kv_cache_enabled(): + # We may calculate per token quant scale in + # reshape_and_cache_shuffle_triton which might differ from + # vllm's style when shuffle layout is used. 
+ k_scale = layer._k_scale + v_scale = layer._v_scale + assert k_scale is not None and v_scale is not None, ( + "k_scale and v_scale are required for shuffled update" + ) + reshape_and_cache_shuffle_triton( + key, + value, + key_cache, + value_cache, + slot_mapping, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + else: + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index 3d8a660c9..db6fd97c9 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, @@ -207,3 +208,42 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): layer._k_scale, layer._v_scale, ) + + def fused_rope_kvcache_supported(self): + return rocm_aiter_ops.is_enabled() + + def do_rope_and_kv_cache_update( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + ): + key_cache, value_cache = kv_cache.unbind(0) + flash_layout = True + + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + + rocm_aiter_ops.triton_rope_and_cache( + query, + key, + value, + positions, + cos_sin_cache, + is_neox, + key_cache, + value_cache, + layer_slot_mapping, + layer._k_scale, + layer._v_scale, + flash_layout, + is_fp8_kv_cache, + ) diff --git 
a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 0b9889c13..d72293dec 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -7,6 +7,7 @@ from typing import ClassVar import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -415,3 +416,46 @@ class RocmAttentionImpl(AttentionImpl): layer._k_scale, layer._v_scale, ) + + def fused_rope_kvcache_supported(self): + return rocm_aiter_ops.is_enabled() + + def do_rope_and_kv_cache_update( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + ): + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, + layer.num_kv_heads, # type: ignore[attr-defined] + layer.head_size, # type: ignore[attr-defined] + ) + flash_layout = False + + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + + rocm_aiter_ops.triton_rope_and_cache( + query, + key, + value, + positions, + cos_sin_cache, + is_neox, + key_cache, + value_cache, + layer_slot_mapping, + layer._k_scale, + layer._v_scale, + flash_layout, + is_fp8_kv_cache, + ) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index c0987dbe4..953d7b3c4 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -7,6 +7,7 @@ from typing import ClassVar import torch +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import CUDAGraphMode, VllmConfig from vllm.config.cache import CacheDType from vllm.logger import init_logger @@ -596,3 +597,42 @@ class 
TritonAttentionImpl(AttentionImpl): layer._k_scale, layer._v_scale, ) + + def fused_rope_kvcache_supported(self): + return rocm_aiter_ops.is_enabled() + + def do_rope_and_kv_cache_update( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + ): + key_cache, value_cache = kv_cache.unbind(1) + flash_layout = True + + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + + rocm_aiter_ops.triton_rope_and_cache( + query, + key, + value, + positions, + cos_sin_cache, + is_neox, + key_cache, + value_cache, + layer_slot_mapping, + layer._k_scale, + layer._v_scale, + flash_layout, + is_fp8_kv_cache, + )