[XPU] Set consistent default KV cache layout (#24745)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -5,8 +5,8 @@ import enum
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, fields, make_dataclass
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
|
||||
TypeVar)
|
||||
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
|
||||
Protocol, TypeVar, Union, get_args)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -30,7 +30,12 @@ from vllm.logger import init_logger
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_KV_CACHE_LAYOUT_OVERRIDE = None
|
||||
KVCacheLayoutType = Literal["NHD", "HND"]
|
||||
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
|
||||
|
||||
|
||||
def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True when ``value`` is one of the ``KVCacheLayoutType`` literals.

    The accepted names are read from the ``Literal`` alias itself, so this
    check follows the alias if more layouts are added to it.
    """
    allowed_layouts = get_args(KVCacheLayoutType)
    return value in allowed_layouts
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -296,12 +301,13 @@ def get_kv_cache_layout():
|
||||
if cache_layout is None:
|
||||
cache_layout = get_kv_connector_cache_layout()
|
||||
else:
|
||||
assert is_valid_kv_cache_layout(cache_layout)
|
||||
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||
return cache_layout
|
||||
|
||||
|
||||
def set_kv_cache_layout(cache_layout: str):
|
||||
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """Record ``cache_layout`` as the module-wide KV cache layout override.

    The value is stored in ``_KV_CACHE_LAYOUT_OVERRIDE`` exactly as given;
    this function performs no runtime validation of the argument.
    """
    # Rebind the module-level name instead of creating a function local.
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
|
||||
|
||||
|
||||
Reference in New Issue
Block a user