[TPU] Optimize kv cache update kernel (#20415)
Signed-off-by: Yifei Teng <tengyifei88@gmail.com>
@@ -947,6 +947,13 @@ def next_power_of_2(n) -> int:
     return 1 << (n - 1).bit_length()
 
 
+def prev_power_of_2(n: int) -> int:
+    """The previous power of 2 (inclusive)"""
+    if n <= 0:
+        return 0
+    return 1 << (n.bit_length() - 1)
+
+
 def round_up(x: int, y: int) -> int:
     return ((x + y - 1) // y) * y
 
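A minimal sketch of how the new prev_power_of_2 helper behaves; the sample values below are illustrative and not part of the patch:

def prev_power_of_2(n: int) -> int:
    """The previous power of 2 (inclusive); 0 for non-positive inputs."""
    if n <= 0:
        return 0
    return 1 << (n.bit_length() - 1)

assert prev_power_of_2(64) == 64    # exact powers of 2 are returned unchanged
assert prev_power_of_2(100) == 64   # other values round down to the nearest power of 2
assert prev_power_of_2(0) == 0      # non-positive inputs map to 0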
@@ -324,3 +324,9 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
                                page_size: int,
                                num_slices_per_block: int) -> torch.Tensor:
     return kv_cache
+
+
+def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
+                        kv_cache_dtype: torch.dtype) -> int:
+    """Returns the size in bytes of one page of the KV cache."""
+    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize
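As a sanity check on the new get_page_size_bytes helper, a hypothetical page configuration (block size 32, 8 KV heads, head size 128, bfloat16 cache) works out as below; the numbers are illustrative only:

import torch

def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                        kv_cache_dtype: torch.dtype) -> int:
    """Returns the size in bytes of one page of the KV cache."""
    return block_size * num_kv_heads * head_size * kv_cache_dtype.itemsize

# Illustrative configuration, not taken from the patch.
page_bytes = get_page_size_bytes(block_size=32, num_kv_heads=8,
                                 head_size=128,
                                 kv_cache_dtype=torch.bfloat16)
assert page_bytes == 65536  # 32 * 8 * 128 bf16 values, 2 bytes each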
@@ -31,9 +31,10 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
-                        is_pin_memory_available)
+                        is_pin_memory_available, prev_power_of_2)
 from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
-                                               PallasMetadata)
+                                               PallasMetadata,
+                                               get_page_size_bytes)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -56,8 +57,6 @@ logger = init_logger(__name__)
 INVALID_TOKEN_ID = -1
 # Smallest output size
 MIN_NUM_SEQS = 8
-# Block size used for kv cache updating kernel
-NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8
 
 
 #########################################################
@@ -139,7 +138,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
+            model_dtype = self.dtype
+            if isinstance(model_dtype, str):
+                self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
+            else:
+                self.kv_cache_dtype = model_dtype
         else:
             self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                 cache_config.cache_dtype]
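The new isinstance branch covers configurations where the model dtype is still a string key rather than a torch.dtype. A minimal stand-alone sketch of that branch; STR_DTYPE_TO_TORCH_DTYPE here is an illustrative subset of vLLM's mapping, and resolve_kv_cache_dtype is a hypothetical name for the inlined logic:

import torch

# Illustrative subset of vLLM's STR_DTYPE_TO_TORCH_DTYPE mapping.
STR_DTYPE_TO_TORCH_DTYPE = {"bfloat16": torch.bfloat16, "half": torch.float16}

def resolve_kv_cache_dtype(model_dtype) -> torch.dtype:
    # Accept either a string dtype name or an actual torch.dtype.
    if isinstance(model_dtype, str):
        return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
    return model_dtype

assert resolve_kv_cache_dtype("bfloat16") is torch.bfloat16
assert resolve_kv_cache_dtype(torch.float16) is torch.float16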
@@ -192,6 +195,14 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.max_num_encoder_input_tokens = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
 
+        self._num_slices_per_kv_cache_update_block = \
+            _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
+                block_size=self.block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                kv_cache_dtype=self.kv_cache_dtype,
+            ))
+
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
         self.kv_caches: list[torch.Tensor] = []
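Putting the two new helpers together, the value cached on the runner here would come out roughly as follows for the hypothetical 64 KiB page used in the earlier example (a sketch; the helper definitions appear in the last hunk of this diff):

# Sketch only; prev_power_of_2 is the helper added to vllm.utils in the first hunk.
from vllm.utils import prev_power_of_2

page_size_bytes = 64 * 1024                  # hypothetical page size from the example above
vmem_limit = 32 * 1024 * 1024                # the kernel's conservative VMEM budget (32 MiB)
num_slices = vmem_limit // page_size_bytes   # 512 candidate slices fit in the budget
num_slices = min(prev_power_of_2(num_slices), 64)  # rounded to a power of 2, capped at 64
assert num_slices == 64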
@@ -719,7 +730,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         num_kv_update_slices = slot_mapping_metadata.shape[0]
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
             padded_total_num_scheduled_tokens, self.max_num_reqs,
-            self.block_size)
+            self.block_size, self._num_slices_per_kv_cache_update_block)
         slot_mapping_metadata = np.pad(
             slot_mapping_metadata,
             [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]],
@@ -750,8 +761,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             num_kv_update_slices=torch.tensor([num_kv_update_slices],
                                               dtype=torch.int32,
                                               device=self.device),
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
         # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
         # request in the batch. While we should not sample any token from this
@@ -1197,7 +1208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32).to(self.device)
         padded_num_slices = _get_padded_num_kv_cache_update_slices(
-            num_tokens, self.max_num_reqs, self.block_size)
+            num_tokens, self.max_num_reqs, self.block_size,
+            self._num_slices_per_kv_cache_update_block)
         num_kv_update_slices = torch.tensor([padded_num_slices],
                                             dtype=torch.int32).to(self.device)
         slot_mapping = torch.zeros((3, padded_num_slices),
@@ -1220,8 +1232,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             query_start_loc=query_start_loc,
             num_seqs=num_seqs,
             num_kv_update_slices=num_kv_update_slices,
-            num_slices_per_kv_cache_update_block=
-            NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK,
+            num_slices_per_kv_cache_update_block=self.
+            _num_slices_per_kv_cache_update_block,
         )
 
         if self.is_multimodal_model:
@@ -1826,19 +1838,41 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int:
     return paddings[index]
 
 
-def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
-                                           page_size: int) -> int:
+def _get_padded_num_kv_cache_update_slices(
+        num_tokens: int, max_num_reqs: int, page_size: int,
+        num_slices_per_kv_cache_update_block: int) -> int:
     """Calculates the padded number of KV cache update slices to avoid
     recompilation."""
     padded_num_slices = 2 * max_num_reqs + num_tokens // page_size
     padded_num_slices = min(padded_num_slices, num_tokens)
     padded_num_slices = (
-        padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1
-    ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \
-        NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK
+        padded_num_slices + num_slices_per_kv_cache_update_block - 1
+    ) // num_slices_per_kv_cache_update_block * \
+        num_slices_per_kv_cache_update_block
     return padded_num_slices
 
 
+def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int:
+    """Find the optimum number of slices to copy per Pallas program instance.
+
+    Increasing the number of slices copied in one instance of the kernel
+    program will increase HBM bandwidth utilization via more in-flight DMAs.
+
+    However, it will also use more VMEM, and experimentally, we observed
+    performance regression at 128 slices on v6e, likely due to running
+    out of scalar registers. Thus this function will limit the number of
+    slices to 64.
+    """
+    # Conservative VMEM usage limit: 32 MiB
+    vmem_limit = 32 * 1024 * 1024
+    num_slices_per_block = vmem_limit // page_size_bytes
+    assert num_slices_per_block > 0, "Number of slices should be positive"
+    num_slices_per_block = prev_power_of_2(num_slices_per_block)
+    if num_slices_per_block > 64:
+        num_slices_per_block = 64
+    return num_slices_per_block
+
+
 def replace_set_lora(model):
 
     def _tpu_set_lora(