[V1][Kernel] Flashinfer HND KV cache layout (#19280)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2025-06-17 15:09:22 +02:00
Committed by: GitHub
Parent: 93aee29fdb
Commit: 4c8f64faa7
6 changed files with 64 additions and 20 deletions

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
-import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)

-FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
-                                             "NHD").upper()
+FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"


class FlashInferBackend(AttentionBackend):

View File

@@ -3,7 +3,6 @@
""" """
KV cache helper for store. KV cache helper for store.
""" """
import torch import torch
import vllm.envs as envs import vllm.envs as envs
@@ -94,15 +93,17 @@ class model_aware_kv_ops_helper:
def get_kv_connector_cache_layout():
+    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
+    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
-    if vllm_config.model_config is None:
-        logger.warning("Unable to detect current VLLM config. " \
+    if vllm_config.model_config is None or kv_config is None:
+        logger.warning_once("Unable to detect current VLLM config. " \
            "Defaulting to NHD kv cache layout.")
    else:
        use_mla = vllm_config.model_config.use_mla
        if not use_mla and kv_config.kv_connector == "NixlConnector":
-            logger.info("NixlConnector detected. Setting KV cache " \
+            logger.info_once("NixlConnector detected. Setting KV cache " \
                "layout to HND for better xfer performance.")
            return "HND"
    return "NHD"

View File

@@ -128,6 +128,7 @@ if TYPE_CHECKING:
    VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
    VLLM_SLEEP_WHEN_IDLE: bool = False
    VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
+    VLLM_KV_CACHE_LAYOUT: Optional[str] = None


def get_default_cache_root():
@@ -879,6 +880,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # processes via zmq.
    "VLLM_MQ_MAX_CHUNK_BYTES_MB":
    lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")),

+    # KV Cache layout used throughout vllm.
+    # Some common values are:
+    # - NHD
+    # - HND
+    # Where N=num_blocks, H=num_heads and D=head_size. The default value will
+    # leave the layout choice to the backend. Mind that backends may only
+    # implement and support a subset of all possible layouts.
+    "VLLM_KV_CACHE_LAYOUT":
+    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None)
}

# --8<-- [end:env-vars-definition]
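Usage sketch (not part of the commit; the model name is illustrative): the variable only needs to be set in the environment before the engine is constructed, and leaving it unset defers the layout choice to the backend and connector logic:

```python
import os

# Force the HND layout; the default (unset) lets the backend decide.
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # illustrative model
```

The same can be done from the shell, e.g. by exporting VLLM_KV_CACHE_LAYOUT=HND before launching `vllm serve`.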

View File

@@ -16,13 +16,12 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)
from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.distributed.kv_transfer.kv_connector.utils import (
-    get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@@ -73,16 +72,15 @@ class FlashAttentionBackend(AttentionBackend):
    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
-        # NOTE When running disaggregated PD with NIXL, HND layout is used for
-        # faster transfer. `stride_order` indicates the permutation that gets
+        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
-        cache_layout = get_kv_connector_cache_layout()
+        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
-            raise ValueError("Unknown cache layout format %s.", cache_layout)
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

View File

@@ -19,7 +19,8 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
@@ -66,6 +67,19 @@ class FlashInferBackend(AttentionBackend):
    ) -> tuple[int, ...]:
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

+    @staticmethod
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets us from
+        # `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
@dataclass
class PerLayerParameters:
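To see what this stride order does in practice, here is a small sketch (toy sizes, not part of the diff) applying the HND permutation to a tensor with the cache shape returned by `get_kv_cache_shape` above; the same permutation is used by the FlashAttention backend earlier in this commit:

```python
import torch

# Toy sizes, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# Shape from get_kv_cache_shape: the last three axes are in NHD order
# (tokens in block, kv heads, head size).
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

# The HND stride order (0, 1, 3, 2, 4) swaps axes 2 and 3: heads before tokens.
hnd_view = kv_cache.permute(0, 1, 3, 2, 4)
print(hnd_view.shape)            # torch.Size([4, 2, 8, 16, 128])
print(hnd_view.is_contiguous())  # False: permute only changes strides
```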
@@ -290,7 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
    def _get_prefill_wrapper(self):
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), get_kv_cache_layout())
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
@@ -303,14 +317,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                num_qo_heads // num_kv_heads > 4)
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
-                "NHD",
+                get_kv_cache_layout(),
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    def _get_cascade_wrapper(self):
        if self._cascade_wrapper is None:
            self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
-                2, self._get_workspace_buffer(), "NHD")
+                2, self._get_workspace_buffer(), get_kv_cache_layout())
        return self._cascade_wrapper

    def _plan(self, attn_metadata: FlashInferMetadata):
@@ -620,6 +634,7 @@ class FlashInferImpl(AttentionImpl):
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefill_tokens = attn_metadata.num_prefill_tokens
+        stride_order = FlashInferBackend.get_kv_cache_stride_order()

        # Regular attention (common case).
        # Decodes are at the front and prefills are at the back,
        # according to reorder_batch()
@@ -634,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
            assert prefill_wrapper._sm_scale == self.scale
            prefill_wrapper.run(
                prefill_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[num_decode_tokens:],
@@ -650,7 +665,7 @@ class FlashInferImpl(AttentionImpl):
            assert decode_wrapper._sm_scale == self.scale
            decode_wrapper.run(
                decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                k_scale=layer._k_scale_float,
                v_scale=layer._v_scale_float,
                out=output[:num_decode_tokens],
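Note that the permute above does not copy the cache: assuming the cache buffer is allocated following `get_kv_cache_stride_order()` (so the bytes are already physically in the chosen layout), the call only adjusts strides so the FlashInfer wrappers see the expected axis order. A minimal sketch of that property (toy tensor, not from the diff):

```python
import torch

stride_order = (0, 1, 3, 2, 4)  # HND, as returned by get_kv_cache_stride_order()
kv_cache = torch.zeros(4, 2, 16, 8, 128)

permuted = kv_cache.permute(*stride_order)

# Same underlying storage, no data movement: only the logical axis
# order presented to the wrapper changes.
assert permuted.data_ptr() == kv_cache.data_ptr()
```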

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
+import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
@@ -12,6 +13,13 @@ if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch

+import vllm.envs as envs
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)


@dataclass
class CommonAttentionMetadata:
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
        raise ValueError(
            error_msg +
            f"must be the same type as the current layer ({expected}).")

+@functools.lru_cache
+def get_kv_cache_layout():
+    # Override with format specified by the user.
+    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
+    if cache_layout is None:
+        cache_layout = get_kv_connector_cache_layout()
+    else:
+        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
+            "detected. Setting KV cache layout to %s.", cache_layout)
+    return cache_layout
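The resolution order is: an explicit VLLM_KV_CACHE_LAYOUT value wins, otherwise the connector heuristic above decides, with "NHD" as the final fallback. Because of `functools.lru_cache` the decision is made once per process, so changing the environment afterwards has no effect, as this illustrative snippet shows:

```python
import os

os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"
from vllm.v1.attention.backends.utils import get_kv_cache_layout

print(get_kv_cache_layout())  # "HND"

os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
print(get_kv_cache_layout())  # still "HND": the first result is cached
```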