[ROCm] AITER fused RoPE+KVCache (#33443)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com> Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Co-authored-by: charlifu <charlifu@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestFP8Layer
|
||||
from vllm.compilation.passes.fusion.act_quant_fusion import (
|
||||
@@ -31,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
TEST_FP8 = current_platform.supports_fp8()
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -198,23 +200,82 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
return [torch.ops.aten.slice_scatter.default]
|
||||
|
||||
|
||||
MODELS = [
|
||||
TestSiluMul,
|
||||
TestFusedAddRMSNorm,
|
||||
TestRotaryEmbedding,
|
||||
TestRotaryEmbeddingSliceScatter,
|
||||
]
|
||||
class TestFunctionWithMutatedArgsAndReturn(torch.nn.Module):
    """Test module around a custom op that mutates its input AND returns a value.

    The op `vllm::function_with_mutated_args_and_return` computes `x + 1` as
    its return value and adds 2 to `x` in place. It is registered lazily, at
    most once per class, the first time an instance is constructed.
    """

    # Flipped to True once the custom op has been registered.
    OP_REGISTERED = False

    def __init__(self):
        super().__init__()
        self.register_test_custom_op()

    @classmethod
    def register_test_custom_op(cls):
        """Register the test custom op; a no-op if it is already registered."""
        if cls.OP_REGISTERED:
            return

        def function_with_mutated_args_and_return_impl(
            x: torch.Tensor,
        ) -> torch.Tensor:
            # The return value is computed before the in-place update of x.
            result = x + 1
            x.add_(2)
            return result

        def function_with_mutated_args_and_return_fake(
            x: torch.Tensor,
        ) -> torch.Tensor:
            return torch.empty_like(x)

        direct_register_custom_op(
            op_name="function_with_mutated_args_and_return",
            op_func=function_with_mutated_args_and_return_impl,
            mutates_args=["x"],
            fake_impl=function_with_mutated_args_and_return_fake,
        )
        cls.OP_REGISTERED = True

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the custom op on `x` and return `(x, op_result)`."""
        # NOTE: x is passed to the op directly (no clone is made here), so the
        # op's in-place update is visible in the returned x.
        op_result = torch.ops.vllm.function_with_mutated_args_and_return(x)
        return x, op_result

    def example_inputs(self, num_tokens=32):
        """Return a 1-tuple holding a random 1-D tensor of `num_tokens` elements."""
        return (torch.randn(num_tokens),)

    def ops_in_model(self, do_fusion):
        """Ops expected in the graph; `do_fusion` is ignored by this model."""
        return [torch.ops.vllm.function_with_mutated_args_and_return.default]

    def ops_not_in_model(self):
        """Ops expected to be absent from the graph (none for this model)."""
        return []
|
||||
|
||||
|
||||
MODELS_AND_DO_FUSION = {
|
||||
TestSiluMul: [True, False],
|
||||
TestFusedAddRMSNorm: [True, False],
|
||||
TestRotaryEmbedding: [False],
|
||||
TestRotaryEmbeddingSliceScatter: [False],
|
||||
TestFunctionWithMutatedArgsAndReturn: [False],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("model_class", MODELS)
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||
@pytest.mark.parametrize(
|
||||
"model_class, do_fusion",
|
||||
[
|
||||
(model_class, do_fusion)
|
||||
for model_class, fusions in MODELS_AND_DO_FUSION.items()
|
||||
for do_fusion in fusions
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Only test on cuda and rocm platform",
|
||||
)
|
||||
def test_fix_functionalization(
|
||||
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
@@ -246,8 +307,14 @@ def test_fix_functionalization(
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
model = model_class()
|
||||
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||
inputs_func = model.example_inputs()
|
||||
inputs_no_func = copy.deepcopy(inputs_func)
|
||||
model_func = model_class()
|
||||
model_no_func = copy.deepcopy(model_func)
|
||||
model_func = torch.compile(model_func, backend=backend_func)
|
||||
model_no_func = torch.compile(model_no_func, backend=backend_no_func)
|
||||
model_func(*inputs_func)
|
||||
model_no_func(*inputs_no_func)
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
@@ -265,3 +332,8 @@ def test_fix_functionalization(
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
|
||||
# TODO (Rohan138): compare the outputs from model_func and model_no_func
|
||||
# currently runs into errors while comparing `TestFusedAddRMSNorm`
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/34996
|
||||
# torch.testing.assert_close(outputs_func, outputs_no_func)
|
||||
|
||||
325
tests/compile/passes/test_rope_kvcache_fusion.py
Normal file
325
tests/compile/passes/test_rope_kvcache_fusion.py
Normal file
@@ -0,0 +1,325 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.config
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
|
||||
from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP
|
||||
from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
|
||||
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.compilation.passes.utility.scatter_split_replace import (
|
||||
ScatterSplitReplacementPass,
|
||||
)
|
||||
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
INDEX_SELECT_OP = torch.ops.aten.index.Tensor
|
||||
VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class QKRoPEKVCacheTestModel(torch.nn.Module):
    """Minimal model exercising RoPE followed by a KV cache update.

    `forward` splits a packed QKV tensor, applies rotary embedding to Q and K,
    and calls `unified_kv_cache_update` on K and V; no attention is computed.
    An `Attention` layer is constructed so that the layer is registered under
    `prefix` and its backend / metadata builder can be used by the test.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        attn_backend: AttentionBackendEnum,
        num_heads: int,
        num_kv_heads: int,
        head_size: int,
        is_neox: bool,
        dtype: torch.dtype,
        device: torch.device,
        prefix: str = "model.layers.0.self_attn.attn",
    ):
        super().__init__()
        cache_config = vllm_config.cache_config

        # Plain (non-module) configuration attributes.
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.block_size = cache_config.block_size
        self.q_size = num_heads * head_size
        self.kv_size = num_kv_heads * head_size
        self.is_neox = is_neox
        self.dtype = dtype
        self.device = device
        self.layer_name = prefix

        self.rotary_emb = RotaryEmbedding(
            head_size,
            rotary_dim=head_size,
            max_position_embeddings=4096,
            base=10000,
            is_neox_style=is_neox,
            dtype=self.dtype,
        )

        # Decides whether the test looks for the RoPE custom op or for the
        # index_select op it is composed of.
        self.enable_rope_custom_op = self.rotary_emb.enabled()

        # The Attention layer registers the layer metadata the fusion pass
        # reads.
        self.attn = Attention(
            num_heads=num_heads,
            head_size=head_size,
            scale=1.0 / head_size**0.5,
            num_kv_heads=num_kv_heads,
            cache_config=cache_config,
            quant_config=vllm_config.quant_config,
            prefix=prefix,
            attn_backend=attn_backend.get_class(),
        )
        self.attn_backend: type[AttentionBackend] = self.attn.get_attn_backend()
        assert not self.attn_backend.forward_includes_kv_cache_update, (
            f"Attention backend {self.attn_backend} does not support fuse_rope_kvcache."
        )
        self.attn._k_scale = self.attn._k_scale.to(device)
        self.attn._v_scale = self.attn._v_scale.to(device)

        # An "fp8*" cache dtype string selects the platform FP8 dtype;
        # anything else falls back to the model dtype.
        if cache_config.cache_dtype.startswith("fp8"):
            self.kv_cache_dtype = FP8_DTYPE
        else:
            self.kv_cache_dtype = self.dtype

        # Metadata builder of the selected attention backend.
        kv_cache_spec = AttentionSpec(
            block_size=self.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=head_size,
            dtype=self.kv_cache_dtype,
        )
        builder_cls = self.attn.attn_backend.get_builder_cls()
        self.builder = builder_cls(
            kv_cache_spec=kv_cache_spec,
            layer_names=[self.attn.layer_name],
            vllm_config=vllm_config,
            device=device,
        )

    def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:
        """Install a zeroed dummy KV cache and build attention metadata.

        The batch consists of `batch_size` sequences with sequence length 1
        and query length 1. A fresh zero-filled KV cache is stored on
        `self.attn.kv_cache` on every call.

        NOTE(review): the value returned is whatever `self.builder.build`
        produces; confirm it matches the declared return type.
        """
        ones = [1] * batch_size
        batch_spec = BatchSpec(seq_lens=ones, query_lens=[1] * batch_size)
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, self.block_size, self.device, arange_block_indices=True
        )

        # Ceil-divide the longest sequence by the block size.
        longest_seq = max(batch_spec.seq_lens)
        blocks_per_seq = (longest_seq - 1) // self.block_size + 1
        num_blocks = batch_size * blocks_per_seq

        # Logical KV cache shape and the backend's stride order (identity
        # order when the backend does not provide one).
        backend = self.attn.attn_backend
        logical_shape = backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size
        )
        try:
            stride_order = backend.get_kv_cache_stride_order()
        except (AttributeError, NotImplementedError):
            stride_order = tuple(range(len(logical_shape)))

        physical_shape = tuple(logical_shape[axis] for axis in stride_order)
        inverse_order = [
            stride_order.index(axis) for axis in range(len(stride_order))
        ]

        # Allocate flat, view in the stride-ordered shape, then permute back
        # to the logical dimension order.
        num_elements = (
            2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size
        )
        flat_cache = torch.zeros(
            num_elements,
            dtype=self.kv_cache_dtype,
            device=self.device,
        )
        kv_cache = flat_cache.view(physical_shape).permute(*inverse_order)

        self.attn.kv_cache = [kv_cache]

        return self.builder.build(
            common_prefix_len=0, common_attn_metadata=common_attn_metadata
        )

    def forward(
        self, qkv: torch.Tensor, positions: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply RoPE to Q/K and update the KV cache.

        Returns `(q, k, v, kv_cache_dummy_dep)` with q/k/v viewed as
        `(-1, heads, head_size)`.
        """
        # Work on a copy so in-place ops leave the caller's tensor untouched.
        qkv = qkv.clone()
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)

        # Only the KV cache update op is exercised, not a full attention
        # forward pass.
        q = q.view(-1, self.num_heads, self.head_size)
        k = k.view(-1, self.num_kv_heads, self.head_size)
        v = v.view(-1, self.num_kv_heads, self.head_size)
        kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
            k, v, self.layer_name
        )
        return q, k, v, kv_cache_dummy_dep

    def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
        """Ops expected in the graph before the fusion pass runs."""
        rope_op = ROTARY_OP if self.enable_rope_custom_op else INDEX_SELECT_OP
        return [rope_op, torch.ops.vllm.unified_kv_cache_update.default]

    def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
        """Ops expected in the graph after the fusion pass runs."""
        return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "attn_backend",
    [
        AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.ROCM_ATTN,
    ],
)
@pytest.mark.parametrize("enable_rope_custom_op", [True])  # [True, False])
@pytest.mark.parametrize("num_heads", [64])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("is_neox", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.skipif(
    not is_aiter_found_and_supported(),
    reason="Only test on ROCm with AITER installed and supported",
)
def test_rope_kvcache_fusion(
    attn_backend: AttentionBackendEnum,
    enable_rope_custom_op: bool,
    num_heads: int,
    num_kv_heads: int,
    head_size: int,
    block_size: int,
    is_neox: bool,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    monkeypatch: pytest.MonkeyPatch,
):
    """Check that RoPE + KV cache update is fused and matches the unfused run.

    The model is run eagerly (unfused) and through `torch.compile` with the
    fusion passes; q/k/v outputs and the resulting KV caches are compared.
    """
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

    custom_ops: list[str] = ["+rotary_embedding"] if enable_rope_custom_op else []

    pass_config = PassConfig(
        fuse_rope_kvcache=True,
        eliminate_noops=True,
    )
    vllm_config = VllmConfig(
        model_config=ModelConfig(dtype=dtype),
        cache_config=CacheConfig(
            block_size=block_size,
            cache_dtype=kv_cache_dtype,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=custom_ops,
            pass_config=pass_config,
        ),
    )

    with vllm.config.set_current_vllm_config(vllm_config), monkeypatch.context() as m:
        m.setenv("VLLM_ROCM_USE_AITER", "1")
        rocm_aiter_ops.refresh_env_variables()

        model = QKRoPEKVCacheTestModel(
            vllm_config=vllm_config,
            attn_backend=attn_backend,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            is_neox=is_neox,
            dtype=dtype,
            device=torch.get_default_device(),
        )

        fusion_pass = RopeKVCacheFusionPass(vllm_config)
        backend = TestBackend(
            NoOpEliminationPass(vllm_config),
            SplitCoalescingPass(vllm_config),
            ScatterSplitReplacementPass(vllm_config),
            fusion_pass,
            PostCleanupPass(vllm_config),
        )

        num_tokens = 5
        qkv_width = num_heads * head_size + 2 * num_kv_heads * head_size

        qkv = torch.randn(num_tokens, qkv_width, dtype=dtype)
        pos = torch.arange(num_tokens, dtype=torch.long)

        # Separate copies for the eager reference run.
        qkv_unfused = qkv.clone()
        pos_unfused = pos.clone()

        # Eager (unfused) reference run.
        with set_forward_context(None, vllm_config):
            forward_context = get_forward_context()
            attn_metadata = model.build_attn_metadata(num_tokens)
            forward_context.slot_mapping = {
                model.layer_name: attn_metadata.slot_mapping
            }
            q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
            attn_layer = forward_context.no_compile_layers[model.layer_name]
            kv_cache_unfused = attn_layer.kv_cache[forward_context.virtual_engine]
            del dummy

        # Compiled (fused) run with a dynamic token dimension.
        torch._dynamo.mark_dynamic(qkv, 0)
        torch._dynamo.mark_dynamic(pos, 0)
        with set_forward_context(None, vllm_config):
            model_fused = torch.compile(model, backend=backend)
            forward_context = get_forward_context()
            attn_metadata = model_fused.build_attn_metadata(num_tokens)
            forward_context.slot_mapping = {
                model.layer_name: attn_metadata.slot_mapping
            }
            q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
            attn_layer = forward_context.no_compile_layers[model.layer_name]
            kv_cache_fused = attn_layer.kv_cache[forward_context.virtual_engine]
            del dummy

        assert fusion_pass.matched_count == 1

        backend.check_before_ops(model.ops_in_model_before())
        backend.check_after_ops(model.ops_in_model_after())

        ATOL, RTOL = (2e-3, 2e-3) if dtype == torch.float16 else (1e-2, 1e-2)

        for unfused, fused in (
            (q_unfused, q_fused),
            (k_unfused, k_fused),
            (v_unfused, v_fused),
        ):
            torch.testing.assert_close(unfused, fused, atol=ATOL, rtol=RTOL)

        # Cannot compare fp8_* directly here, cast to model dtype instead
        torch.testing.assert_close(
            kv_cache_unfused.view(dtype),
            kv_cache_fused.view(dtype),
            atol=ATOL,
            rtol=RTOL,
        )
|
||||
107
tests/compile/passes/test_scatter_split_replace.py
Normal file
107
tests/compile/passes/test_scatter_split_replace.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm
|
||||
from tests.compile.backend import TestBackend
|
||||
from vllm.compilation.passes.utility.scatter_split_replace import (
|
||||
ScatterSplitReplacementPass,
|
||||
)
|
||||
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
|
||||
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
|
||||
class ScatterSplitReplacementModel(nn.Module):
    """Model with a rope+getitem+slice_scatter+split_with_sizes sequence.

    `forward` splits a packed QKV tensor, applies neox-style rotary embedding
    to Q and K, and adds a distinct constant to each of Q, K and V.
    """

    def __init__(
        self,
        num_heads: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
    ):
        super().__init__()
        # Widths of the Q and K/V slices in the packed QKV tensor.
        self.q_size = head_size * num_heads
        self.kv_size = head_size * num_kv_heads

        self.rotary_emb = RotaryEmbedding(
            head_size,
            rotary_dim=head_size,
            max_position_embeddings=4096,
            base=10000,
            is_neox_style=True,
            dtype=dtype,
        )

    def forward(self, qkv: torch.Tensor, positions: torch.Tensor):
        """Return `(q + 1, k + 2, v + 3)` after applying RoPE to q and k."""
        # Operate on a copy so in-place ops cannot modify the caller's tensor.
        packed = qkv.clone()
        q, k, v = packed.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        q = q + 1
        k = k + 2
        v = v + 3
        return q, k, v

    def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
        """Ops expected in the graph before the replacement pass."""
        # NOTE(review): confirm `torch.ops.aten.getitem.default` resolves in the
        # targeted torch version.
        expected = [
            torch.ops.aten.slice_scatter.default,
            torch.ops.aten.split_with_sizes.default,
            torch.ops.aten.getitem.default,
        ]
        return expected

    def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
        """Ops expected in the graph after the replacement pass."""
        expected = [torch.ops.aten.getitem.default]
        return expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_scatter_split_replace(dtype):
    """Compiled output must match eager, with slice_scatter removed.

    After SplitCoalescingPass + ScatterSplitReplacementPass the graph is
    expected to contain no `slice_scatter` and exactly one `split_with_sizes`.
    """
    torch.set_default_device("cuda")
    torch.set_default_dtype(dtype)
    torch.manual_seed(0)

    num_heads, num_kv_heads, head_size = 8, 4, 64

    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+rotary_embedding"],
    )
    vllm_config = VllmConfig(compilation_config=compilation_config)

    with vllm.config.set_current_vllm_config(vllm_config):
        # ScatterSplitReplacementPass requires SplitCoalescingPass to be run before it
        backend = TestBackend(
            SplitCoalescingPass(vllm_config),
            ScatterSplitReplacementPass(vllm_config),
        )

        model = ScatterSplitReplacementModel(num_heads, num_kv_heads, head_size, dtype)

        num_tokens = 5
        qkv_width = num_heads * head_size + 2 * num_kv_heads * head_size
        qkv = torch.randn(num_tokens, qkv_width, dtype=dtype)
        pos = torch.arange(num_tokens, dtype=torch.long)

        # Eager reference on separate copies of the inputs.
        result_eager = model(qkv.clone(), pos.clone())

        # Compiled run with a dynamic token dimension.
        torch._dynamo.mark_dynamic(qkv, 0)
        torch._dynamo.mark_dynamic(pos, 0)

        model_compiled = torch.compile(model, backend=backend)
        result_compiled = model_compiled(qkv, pos)

        for expected, actual in zip(result_eager, result_compiled):
            torch.testing.assert_close(expected, actual)

        assert backend.op_count(torch.ops.aten.slice_scatter.default) == 0
        assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1
|
||||
@@ -179,7 +179,7 @@ def create_and_prepopulate_kv_cache(
|
||||
block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
|
||||
start_block_idx += num_blocks_for_seq
|
||||
|
||||
# Create a realistic slot mapping that corresponds to the block table
|
||||
# Create a realistic slot mapping that corresponds to the block table
|
||||
for i in range(batch_size):
|
||||
token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i])
|
||||
block_indices = token_offsets // block_size
|
||||
|
||||
@@ -1518,6 +1518,45 @@ class rocm_aiter_ops:
|
||||
query = query.view(query_shape)
|
||||
key = key.view(key_shape)
|
||||
|
||||
@staticmethod
def triton_rope_and_cache(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    flash_layout: bool,
    apply_scale: bool,
):
    """Call AITER's Triton fused QK-RoPE + reshape-and-cache kernel.

    `cos_sin_cache` is split into two equal chunks along its last dimension
    (cos first, sin second). `query` and `key` are also passed as the kernel's
    `q_out` / `k_out` buffers, and `output_zeros` is disabled. Nothing is
    returned.
    """
    # Imported lazily so the module does not require AITER at import time.
    from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache

    cos_part, sin_part = cos_sin_cache.chunk(2, dim=-1)
    fused_qk_rope_reshape_and_cache(
        query,
        key,
        value,
        key_cache,
        value_cache,
        layer_slot_mapping,
        positions,
        cos_part,
        sin_part,
        k_scale,
        v_scale,
        is_neox,
        flash_layout=flash_layout,
        apply_scale=apply_scale,
        q_out=query,
        k_out=key,
        output_zeros=False,
    )
|
||||
|
||||
@staticmethod
|
||||
def batched_gemm_a16wfp4(
|
||||
X: torch.Tensor,
|
||||
|
||||
230
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Normal file
230
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.attention import (
|
||||
Attention,
|
||||
get_attention_context,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherRotaryEmbedding,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
empty_bf16,
|
||||
empty_i64,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    Real implementation of the fused RoPE + KV cache update custom op.

    The attention layer, its KV cache and the layer's slot mapping are looked
    up via `get_attention_context(layer_name)`. When a slot mapping is
    available, the work is delegated to the layer impl's
    `do_rope_and_kv_cache_update`; otherwise no update is performed.

    The returned empty tensor (on the KV cache's device, with its dtype) is a
    dummy, like the one from `Attention.unified_kv_cache_update`: it is passed
    to unified_attention to signal a side effect and create the data dependency
    that lets torch.compile preserve ordering.
    """
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    dummy_dep = torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)

    if layer_slot_mapping is None:
        # No slot mapping for this layer: skip the update.
        return dummy_dep

    attn_layer.impl.do_rope_and_kv_cache_update(
        attn_layer,
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        kv_cache,
        layer_slot_mapping,
    )
    return dummy_dep
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    Fake (meta) implementation of the fused RoPE + KV cache update op.

    Performs no work and returns an empty 1-D tensor that takes its device
    and dtype from `query`; all other arguments are ignored.
    """
    return query.new_empty(0)
|
||||
|
||||
|
||||
# Register the fused op under the name "fused_rope_and_unified_kv_cache_update".
# Only `query` and `key` are declared as mutated arguments.
direct_register_custom_op(
    op_name="fused_rope_and_unified_kv_cache_update",
    op_func=fused_rope_and_unified_kv_cache_update_impl,
    fake_impl=fused_rope_and_unified_kv_cache_update_fake,
    mutates_args=["query", "key"],
)
|
||||
|
||||
|
||||
class RopeReshapeKVCachePattern:
    """
    Pattern/replacement pair for a single attention layer.

    Matched (unfused, inplace) sequence:
      q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
      kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    Replacement (fused, inplace):
      kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
          q, k, v, positions, cos_sin_cache, is_neox, layer_name
      )
    """

    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        self.is_neox = is_neox

        # Layer geometry copied from the attention layer.
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v

        # Widths of the q / k / v slices in the packed qkv tensor.
        self.q_size = self.head_size * self.num_heads
        self.k_size = self.head_size * self.num_kv_heads
        self.v_size = self.head_size_v * self.num_kv_heads

        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        """Return sample `[qkv, positions, cos_sin_cache]` for pattern tracing."""
        T = 5
        L = 4096
        qkv_width = self.q_size + self.k_size + self.v_size
        qkv = empty_bf16(T, qkv_width)
        positions = empty_i64(T)
        cos_sin_cache = empty_bf16(L, self.head_size)
        return [qkv, positions, cos_sin_cache]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this layer's pattern and replacement on `pm_pass`."""

        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
            return dummy, q, k, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            fused = auto_functionalized(
                self.FUSED_OP,
                query=q,
                key=k,
                value=v,
                positions=positions,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                layer_name=self.layer_name,
            )
            # Same output order as `pattern`: (dummy, q, k, v). v is passed
            # through unchanged.
            return fused[0], fused[1], fused[2], v

        # NOTE: views are normalized to reshapes after tracing so that
        # view/reshape differences do not prevent a match.
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            traced = pm.fwd_only(*args, **kwargs)
            view_to_reshape(traced)
            return traced

        sample_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, sample_inputs, fwd_and_view_to_reshape, pm_pass
        )
|
||||
|
||||
|
||||
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    Fuse rotary embedding and the KV cache update into one op where supported.

    Patterns are registered per attention layer because the layer name is a
    string argument of the matched op and strings cannot be wildcarded. This
    also means layer support is checked once at registration time rather than
    during pattern matching.

    The fusion removes the separate kernel launches and the intermediate
    memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        pass_config = config.compilation_config.pass_config
        self.max_token_num = pass_config.rope_kvcache_fusion_max_token_num

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in attn_layers.values():
            if not layer.impl.fused_rope_kvcache_supported():
                continue
            # Register both RoPE styles for every supported layer.
            for is_neox in (True, False):
                layer_pattern = RopeReshapeKVCachePattern(
                    layer=layer,
                    is_neox=is_neox,
                )
                layer_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to `graph` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True when the compile range ends at or below `max_token_num`."""
        # The fused op targets the small-batch decode setting. For large
        # batches (e.g. prefill) two separate kernels are preferred: those
        # cases are compute bound and the fused kernels need further tuning.
        upper_bound = compile_range.end
        return upper_bound <= self.max_token_num

    def uuid(self) -> str:
        """Hash of this pass's source together with the pattern class source."""
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)
|
||||
@@ -28,7 +28,9 @@ if current_platform.is_cuda_alike():
|
||||
from .fusion.attn_quant_fusion import AttnFusionPass
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from .utility.scatter_split_replace import ScatterSplitReplacementPass
|
||||
from .utility.split_coalescing import SplitCoalescingPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
@@ -136,6 +138,11 @@ class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_rope_kvcache:
|
||||
self.passes += [SplitCoalescingPass(config)]
|
||||
self.passes += [ScatterSplitReplacementPass(config)]
|
||||
self.passes += [RopeKVCacheFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
|
||||
@@ -162,6 +162,24 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
"position_ids",
|
||||
)
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
|
||||
elif (
|
||||
hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
|
||||
and at_target
|
||||
== torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
|
||||
):
|
||||
mutated_args = {
|
||||
1: "query",
|
||||
2: "key",
|
||||
}
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args)
|
||||
# only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
|
||||
elif (
|
||||
hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
|
||||
and at_target
|
||||
== torch.ops.vllm.function_with_mutated_args_and_return.default
|
||||
):
|
||||
mutated_args = {1: "x"}
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args)
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
@@ -208,13 +226,20 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
|
||||
) -> None:
|
||||
"""
|
||||
Replace all getitem users of the auto-functionalized node with the
|
||||
Replace mutated getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
:param node: The auto-functionalized node
|
||||
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||
"""
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
# Some functionalized nodes may return both a result at getitem[0]
|
||||
# as well as mutated args at getitem[1:...]
|
||||
if idx == 0:
|
||||
assert idx not in mutated_args, (
|
||||
f"result at getitem[0] should not be in mutated_args for {node}"
|
||||
)
|
||||
continue
|
||||
arg = mutated_args[idx]
|
||||
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
user.replace_all_uses_with(arg)
|
||||
@@ -257,10 +282,20 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
with graph.inserting_before(node):
|
||||
function = node.args[0]
|
||||
if args is None:
|
||||
graph.call_function(function, kwargs=node.kwargs)
|
||||
fn_node = graph.call_function(function, kwargs=node.kwargs)
|
||||
else:
|
||||
# Args passed as strings refer to items in node.kwargs
|
||||
args = tuple(
|
||||
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
|
||||
)
|
||||
graph.call_function(function, args=args)
|
||||
fn_node = graph.call_function(function, args=args)
|
||||
|
||||
# If the function returns a value as well as mutating args inplace,
|
||||
# the functionalized node will have a getitem[0] user that holds this value
|
||||
# Replace getitem[0] user of the auto-functionalized node
|
||||
# with the new defunctionalized node directly if it exists
|
||||
users = self.getitem_users(node)
|
||||
if 0 in users:
|
||||
user = users[0]
|
||||
user.replace_all_uses_with(fn_node)
|
||||
self._remove(user)
|
||||
|
||||
134
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
134
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
|
||||
assignment if there are no users for the inplace tensor written to by
|
||||
the slice_scatter call.
|
||||
|
||||
The inplace rotary_embedding custom op takes in mutable query and key inputs
|
||||
that are split+getitem outputs of a single qkv tensor.
|
||||
When functionalized, we fetch the rotated query and key from the functionalized op
|
||||
using `getitem` calls. However, we also write to the qkv tensor inplace using a
|
||||
`slice_scatter`, then split the inplace tensor to get the output tensors again.
|
||||
Instead, if the inplace tensor has no subsequent users, we can just replace the
|
||||
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
|
||||
|
||||
This is already done in fix_functionalization::FixFunctionalizationPass, but
|
||||
writing a custom pass for it before defunctionalization allows matching against the
|
||||
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ScatterSplitReplacementPass(VllmInductorPass):
    """Replace getitem+slice_scatter+split nodes with a single getitem when
    the inplace subtensor written to by the slice_scatter has no other users.

    Here's an example graph with q_size = 512, kv_size = 64:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
    torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    q = operator.getitem(split_with_sizes_2, 0)
    k = operator.getitem(split_with_sizes_2, 1)
    v = operator.getitem(split_with_sizes_2, 2)

    After this pass, this sequence of nodes is replaced with:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    v = operator.getitem(split_with_sizes_1, 2)

    A candidate is rewritten only when the complete chain is present and
    every node that would be erased ends up without users; otherwise the
    graph is left untouched for that rotary_embedding node.
    """

    @staticmethod
    def _sole_user(node: fx.Node, target: object) -> fx.Node | None:
        """Return the single user of ``node`` if it calls ``target``, else None."""
        users = list(node.users)
        if len(users) == 1 and is_func(users[0], target):
            return users[0]
        return None

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Rewrite every matching rotary_embedding scatter/split chain in ``graph``."""
        count = 0
        slice_scatter = torch.ops.aten.slice_scatter.default
        split_with_sizes = torch.ops.aten.split_with_sizes.default

        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue
            if node.args[0] != torch.ops._C.rotary_embedding.default:
                continue

            kwargs = node.kwargs
            query = kwargs["query"]
            key = kwargs["key"]

            # getitem users of the functionalized op, keyed by result index.
            getitem_nodes = {
                user.args[1]: user
                for user in node.users
                if is_func(user, operator.getitem)
            }

            # Pattern where query and key are slices of a qkv tensor.
            # While functionalized, results at [1] and [2] are scattered
            # back into qkv, then split again to get query and key.
            is_qkv_split_pattern = (
                is_func(query, operator.getitem)
                and is_func(key, operator.getitem)
                and query.args[0] == key.args[0]
                and is_func(query.args[0], split_with_sizes)
                and all(
                    is_func(user, slice_scatter)
                    for getitem_node in getitem_nodes.values()
                    for user in getitem_node.users
                )
            )
            if not is_qkv_split_pattern:
                continue

            # Both rotated outputs must actually be read from the op; the
            # all() above is vacuously true when there are no getitem users.
            if 1 not in getitem_nodes or 2 not in getitem_nodes:
                continue

            slice_scatter_1_node = self._sole_user(getitem_nodes[1], slice_scatter)
            slice_scatter_2_node = self._sole_user(getitem_nodes[2], slice_scatter)
            if slice_scatter_1_node is None or slice_scatter_2_node is None:
                continue

            # The second scatter must feed the re-split and nothing else,
            # otherwise the inplace tensor has other users and must be kept.
            split_node = self._sole_user(slice_scatter_2_node, split_with_sizes)
            if split_node is None:
                continue

            # The first scatter may only be consumed by the second scatter
            # (or by nothing), since it is erased below.
            if any(
                user is not slice_scatter_2_node
                for user in slice_scatter_1_node.users
            ):
                continue

            split_getitem_users = {
                user.args[1]: user
                for user in split_node.users
                if is_func(user, operator.getitem)
            }
            # Require exactly the query/key/value getitems so that the split
            # node has no users left once they are rewired; erase_node raises
            # on nodes that still have users.
            if len(split_node.users) != 3 or any(
                idx not in split_getitem_users for idx in (0, 1, 2)
            ):
                continue

            # If the inplace tensor has no other users, we can replace
            # the slice_scatter+split nodes with the original results.
            # Replace query node
            split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
            graph.erase_node(split_getitem_users[0])
            # Replace key node
            split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
            graph.erase_node(split_getitem_users[1])
            # Redirect value node to original qkv tensor
            split_getitem_users[2].replace_input_with(split_node, query.args[0])

            # Erase unused nodes, consumers before producers.
            graph.erase_node(split_node)
            graph.erase_node(slice_scatter_2_node)
            graph.erase_node(slice_scatter_1_node)

            count += 1

        logger.debug("Eliminated %d slice_scatter+split nodes", count)
|
||||
@@ -127,6 +127,13 @@ class PassConfig:
|
||||
# ROCm/AITER specific fusions
|
||||
fuse_act_padding: bool = Field(default=None)
|
||||
"""Fuse the custom RMSNorm + padding ops."""
|
||||
fuse_rope_kvcache: bool = Field(default=None)
|
||||
"""Fuse the QK rope + KV cache ops."""
|
||||
|
||||
rope_kvcache_fusion_max_token_num: int = 256
|
||||
"""The threshold for ROCm AITER RoPE+KVCache fusion e.g. for small batch decode.
|
||||
Larger batch sizes e.g. during prefill will use the unfused kernels.
|
||||
"""
|
||||
|
||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||
"""The threshold of the communicated tensor sizes under which
|
||||
@@ -198,6 +205,7 @@ class PassConfig:
|
||||
"fuse_gemm_comms",
|
||||
"fuse_allreduce_rms",
|
||||
"fuse_act_padding",
|
||||
"fuse_rope_kvcache",
|
||||
mode="wrap",
|
||||
)
|
||||
@classmethod
|
||||
@@ -243,6 +251,12 @@ class PassConfig:
|
||||
"The fusion will be disabled."
|
||||
)
|
||||
self.fuse_act_padding = False
|
||||
if self.fuse_rope_kvcache and not current_platform.is_rocm():
|
||||
logger.warning_once(
|
||||
"KV cache fusion currently only enabled on ROCm. "
|
||||
"The fusion will be disabled."
|
||||
)
|
||||
self.fuse_rope_kvcache = False
|
||||
|
||||
|
||||
class DynamicShapesType(str, enum.Enum):
|
||||
@@ -824,6 +838,19 @@ class CompilationConfig:
|
||||
# TODO(zhuhaoran): support rope native forward match and remove this.
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
|
||||
self.custom_ops.append("+rotary_embedding")
|
||||
if self.pass_config.fuse_rope_kvcache:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if rocm_aiter_ops.is_triton_rotary_embed_enabled():
|
||||
logger.warning(
|
||||
"Cannot use VLLM_ROCM_USE_AITER_TRITON_ROPE with "
|
||||
"fuse_rope_kvcache. Disabling fuse_rope_kvcache."
|
||||
)
|
||||
self.pass_config.fuse_rope_kvcache = False
|
||||
else:
|
||||
# TODO(Rohan138): support rope native forward match and remove this.
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/28042
|
||||
self.custom_ops.append("+rotary_embedding")
|
||||
|
||||
if (
|
||||
is_torch_equal_or_newer("2.9.0.dev")
|
||||
|
||||
@@ -1401,6 +1401,20 @@ class VllmConfig:
|
||||
"allreduce-rms fusion will be enabled for all num_tokens."
|
||||
)
|
||||
|
||||
if compilation_config.pass_config.fuse_rope_kvcache:
|
||||
max_token_num = (
|
||||
compilation_config.pass_config.rope_kvcache_fusion_max_token_num
|
||||
)
|
||||
if max_token_num is not None:
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_split_points.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
"Max num batched tokens below rope+kvcache fusion threshold, "
|
||||
"rope+kvcache fusion enabled for num_tokens <= %d.",
|
||||
compile_range_end,
|
||||
)
|
||||
|
||||
if compilation_config.compile_ranges_split_points is not None:
|
||||
for x in compilation_config.compile_ranges_split_points:
|
||||
assert isinstance(x, int)
|
||||
|
||||
@@ -570,11 +570,11 @@ direct_register_custom_op(
|
||||
|
||||
def get_attention_context(
|
||||
layer_name: str,
|
||||
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor]:
|
||||
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
|
||||
"""Extract attention context for a given layer.
|
||||
|
||||
This helper function extracts the attention metadata, attention layer
|
||||
instance, and KV cache tensor for a specific layer.
|
||||
instance, KV cache tensor, and slot mapping for a specific layer.
|
||||
|
||||
Args:
|
||||
layer_name: The name/identifier of the attention layer.
|
||||
@@ -585,6 +585,7 @@ def get_attention_context(
|
||||
no metadata available
|
||||
- attn_layer: The attention layer instance (Attention or MLAAttention)
|
||||
- kv_cache: The KV cache tensor for current virtual engine
|
||||
- slot_mapping: The slot mapping for this specific layer
|
||||
|
||||
Note: attn_metadata may be None, but attn_layer and kv_cache are always
|
||||
extracted from the forward context.
|
||||
@@ -593,9 +594,14 @@ def get_attention_context(
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
|
||||
return attn_metadata, attn_layer, kv_cache
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
)
|
||||
layer_slot_mapping = slot_mapping.get(layer_name)
|
||||
return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
@@ -605,7 +611,7 @@ def unified_attention(
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
|
||||
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||
|
||||
return output
|
||||
@@ -636,15 +642,7 @@ def unified_kv_cache_update(
|
||||
Returns a dummy that is passed to unified_attention to signal a side effect and
|
||||
the data dependency between them to ensure torch.compile preserves ordering.
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
attn_layer = forward_context.no_compile_layers[layer_name]
|
||||
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
|
||||
|
||||
slot_mapping = forward_context.slot_mapping
|
||||
assert isinstance(slot_mapping, dict), (
|
||||
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
|
||||
)
|
||||
layer_slot_mapping = slot_mapping.get(layer_name)
|
||||
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
|
||||
if layer_slot_mapping is not None:
|
||||
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
|
||||
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
|
||||
@@ -691,7 +689,7 @@ def unified_attention_with_output(
|
||||
# that ensures torch.compile preserves ordering between KV cache update and
|
||||
# attention forward.
|
||||
del kv_cache_dummy_dep
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
|
||||
|
||||
self.impl.forward(
|
||||
self,
|
||||
|
||||
@@ -40,8 +40,8 @@ def maybe_transfer_kv_layer(func: Callable) -> Callable:
|
||||
|
||||
layer_name: str = args[layer_name_index]
|
||||
|
||||
# Extract attention context (layer-specific metadata, layer, and kv_cache)
|
||||
attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
|
||||
# Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
|
||||
attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
|
||||
connector = get_kv_transfer_group()
|
||||
if attn_metadata is None or not connector.has_connector_metadata():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@@ -828,7 +828,7 @@ def unified_mla_attention(
|
||||
k_pe: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
|
||||
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
|
||||
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output
|
||||
@@ -862,7 +862,7 @@ def unified_mla_attention_with_output(
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
attn_metadata, layer, kv_cache = get_attention_context(layer_name)
|
||||
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
|
||||
layer.forward_impl(
|
||||
q,
|
||||
kv_c_normed,
|
||||
|
||||
@@ -723,6 +723,33 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""
|
||||
return False
|
||||
|
||||
def fused_rope_kvcache_supported(self):
    """
    Report whether this attention implementation supports RoPE+KVCache
    fusion.

    RopeKVCacheFusionPass queries this so that the RoPE ops are fused with
    the KV cache update only for implementations that support it. This
    default implementation reports no support.
    """
    return False
|
||||
|
||||
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """
    Perform the inplace RoPE and KV cache update for this implementation.

    torch.ops.vllm.fused_rope_and_unified_kv_cache_update calls this method
    when `fused_rope_kvcache_supported` returns True. Implementations that
    report support must override it; this default raises.

    Raises:
        NotImplementedError: always, in this default implementation.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
class MLAAttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
"""MLA attention implementation with forward_mqa and forward_mha methods."""
|
||||
|
||||
@@ -11,7 +11,6 @@ from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.attention.attention import get_attention_context
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
@@ -1290,11 +1289,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
attn_metadata, _, _ = get_attention_context(layer.layer_name)
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return
|
||||
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
@@ -1303,45 +1297,40 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||
# is not padded. However, we don't need to do
|
||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||
# to determine the number of actual tokens.
|
||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||
# We may calculate per token quant scale in
|
||||
# reshape_and_cache_shuffle_triton which might differ from
|
||||
# vllm's style when shuffle layout is used.
|
||||
k_scale = attn_metadata.k_scale
|
||||
v_scale = attn_metadata.v_scale
|
||||
assert k_scale is not None and v_scale is not None, (
|
||||
"k_scale and v_scale are required for shuffled update"
|
||||
)
|
||||
reshape_and_cache_shuffle_triton(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||
# is not padded. However, we don't need to do
|
||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||
# to determine the number of actual tokens.
|
||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||
# We may calculate per token quant scale in
|
||||
# reshape_and_cache_shuffle_triton which might differ from
|
||||
# vllm's style when shuffle layout is used.
|
||||
k_scale = layer._k_scale
|
||||
v_scale = layer._v_scale
|
||||
assert k_scale is not None and v_scale is not None, (
|
||||
"k_scale and v_scale are required for shuffled update"
|
||||
)
|
||||
reshape_and_cache_shuffle_triton(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
@@ -207,3 +208,42 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
    """Return whether RoPE+KVCache fusion is supported.

    The answer is whatever `rocm_aiter_ops.is_enabled()` reports.
    """
    aiter_enabled = rocm_aiter_ops.is_enabled()
    return aiter_enabled
|
||||
|
||||
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """
    Run the fused RoPE + KV cache update through
    `rocm_aiter_ops.triton_rope_and_cache`.

    The key and value caches are obtained by unbinding `kv_cache` along
    dimension 0, and are reinterpreted as `self.fp8_dtype` when
    `self.kv_cache_dtype` starts with "fp8". The kernel is invoked with the
    flash layout flag set.
    """
    k_cache, v_cache = kv_cache.unbind(0)

    uses_fp8_cache = self.kv_cache_dtype.startswith("fp8")
    if uses_fp8_cache:
        # Reinterpret the cache storage with the fp8 dtype.
        k_cache = k_cache.view(self.fp8_dtype)
        v_cache = v_cache.view(self.fp8_dtype)

    rocm_aiter_ops.triton_rope_and_cache(
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        k_cache,
        v_cache,
        layer_slot_mapping,
        layer._k_scale,
        layer._v_scale,
        True,  # flash_layout
        uses_fp8_cache,
    )
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
@@ -415,3 +416,46 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
    """Return whether RoPE+KVCache fusion is supported.

    The answer is whatever `rocm_aiter_ops.is_enabled()` reports.
    """
    aiter_enabled = rocm_aiter_ops.is_enabled()
    return aiter_enabled
|
||||
|
||||
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """
    Run the fused RoPE + KV cache update through
    `rocm_aiter_ops.triton_rope_and_cache`.

    The key and value caches come from `PagedAttention.split_kv_cache`,
    using the layer's `num_kv_heads` and `head_size`, and are reinterpreted
    as `self.fp8_dtype` when `self.kv_cache_dtype` starts with "fp8". The
    kernel is invoked with the flash layout flag cleared.
    """
    k_cache, v_cache = PagedAttention.split_kv_cache(
        kv_cache,
        layer.num_kv_heads,  # type: ignore[attr-defined]
        layer.head_size,  # type: ignore[attr-defined]
    )

    uses_fp8_cache = self.kv_cache_dtype.startswith("fp8")
    if uses_fp8_cache:
        # Reinterpret the cache storage with the fp8 dtype.
        k_cache = k_cache.view(self.fp8_dtype)
        v_cache = v_cache.view(self.fp8_dtype)

    rocm_aiter_ops.triton_rope_and_cache(
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        k_cache,
        v_cache,
        layer_slot_mapping,
        layer._k_scale,
        layer._v_scale,
        False,  # flash_layout
        uses_fp8_cache,
    )
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -596,3 +597,42 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
def fused_rope_kvcache_supported(self):
    """Return whether RoPE+KVCache fusion is supported.

    The answer is whatever `rocm_aiter_ops.is_enabled()` reports.
    """
    aiter_enabled = rocm_aiter_ops.is_enabled()
    return aiter_enabled
|
||||
|
||||
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """
    Run the fused RoPE + KV cache update through
    `rocm_aiter_ops.triton_rope_and_cache`.

    The key and value caches are obtained by unbinding `kv_cache` along
    dimension 1, and are reinterpreted as `self.fp8_dtype` when
    `self.kv_cache_dtype` starts with "fp8". The kernel is invoked with the
    flash layout flag set.
    """
    k_cache, v_cache = kv_cache.unbind(1)

    uses_fp8_cache = self.kv_cache_dtype.startswith("fp8")
    if uses_fp8_cache:
        # Reinterpret the cache storage with the fp8 dtype.
        k_cache = k_cache.view(self.fp8_dtype)
        v_cache = v_cache.view(self.fp8_dtype)

    rocm_aiter_ops.triton_rope_and_cache(
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        k_cache,
        v_cache,
        layer_slot_mapping,
        layer._k_scale,
        layer._v_scale,
        True,  # flash_layout
        uses_fp8_cache,
    )
|
||||
|
||||
Reference in New Issue
Block a user