[V1][Kernel] Flashinfer HND KV cache layout (#19280)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2025-06-17 15:09:22 +02:00
Committed by: GitHub
Parent: 93aee29fdb
Commit: 4c8f64faa7
6 changed files with 64 additions and 20 deletions


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
@@ -12,6 +13,13 @@ if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
+import vllm.envs as envs
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 @dataclass
 class CommonAttentionMetadata:
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
         raise ValueError(
             error_msg +
             f"must be the same type as the current layer ({expected}).")
+
+
+@functools.lru_cache
+def get_kv_cache_layout():
+    # Override with the format specified by the user, if any.
+    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
+    if cache_layout is None:
+        cache_layout = get_kv_connector_cache_layout()
+    else:
+        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable "
+                         "detected. Setting KV cache layout to %s.",
+                         cache_layout)
+    return cache_layout
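
For illustration, a minimal usage sketch of the new helper. The module path `vllm.v1.attention.backends.utils` and the per-page tensor shapes below are assumptions for this example (following FlashInfer's documented convention, where a KV cache page is `(page_size, num_kv_heads, head_dim)` for NHD and `(num_kv_heads, page_size, head_dim)` for HND); only `get_kv_cache_layout` and `VLLM_KV_CACHE_LAYOUT` come from the diff itself.

```python
# Usage sketch, not part of the commit. Assumes the diff above lands in
# vllm/v1/attention/backends/utils.py and that vLLM is installed.
import os

# Set the env var before the first call: get_kv_cache_layout() is wrapped
# in functools.lru_cache, so the layout is resolved once per process.
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"

from vllm.v1.attention.backends.utils import get_kv_cache_layout

layout = get_kv_cache_layout()  # logs the override and returns "HND"

# Illustrative per-page KV cache shapes under FlashInfer's convention
# (example sizes, not taken from the diff):
block_size, num_kv_heads, head_dim = 16, 8, 128
page_shape = ((num_kv_heads, block_size, head_dim) if layout == "HND"
              else (block_size, num_kv_heads, head_dim))
print(layout, page_shape)  # HND (8, 16, 128)
```

Note the `functools.lru_cache` design choice: the layout (user override or connector-provided default) is computed once and stays fixed for the lifetime of the worker process.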