[Hardware][TPU] Enable ragged paged attention kernel and resolve recompilation issue (#14310)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao
2025-03-06 15:31:05 -08:00
committed by GitHub
parent 04222984f8
commit 0578e5a462
3 changed files with 58 additions and 66 deletions

View File

@@ -14,7 +14,7 @@ import torch_xla.runtime as xr
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
@@ -416,8 +416,8 @@ class TPUModelRunner:
num_scheduled_tokens_per_req)
# Do the padding and copy the tensors to the TPU.
padded_total_num_scheduled_tokens = _get_padded_number(
total_num_scheduled_tokens, NUM_QUERIES_PER_BLOCK)
padded_total_num_scheduled_tokens = _get_padded_token_len(
total_num_scheduled_tokens)
self.input_ids = self.input_ids_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
@@ -428,23 +428,22 @@ class TPUModelRunner:
slot_mapping = self.slot_mapping_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
padded_block_table = self.block_table_cpu[:
padded_total_num_scheduled_tokens]
padded_block_table[:num_reqs, :self.max_num_blocks_per_req] = (
block_tables = self.block_table_cpu[:self.max_num_reqs]
block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
padded_block_table = padded_block_table.to(self.device)
query_start_loc = self.query_start_loc_cpu[:
padded_total_num_scheduled_tokens
+ 1].to(self.device)
seq_lens = self.seq_lens_cpu[:padded_total_num_scheduled_tokens].to(
block_tables = block_tables.to(self.device)
query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
self.device)
seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=padded_block_table,
block_tables=block_tables,
context_lens=seq_lens,
query_start_loc=query_start_loc,
num_seqs=num_reqs,
num_seqs=torch.tensor([num_reqs],
dtype=torch.int32,
device=self.device),
)
# NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
# request in the batch. While we should not sample any token from this
@@ -693,29 +692,34 @@ class TPUModelRunner:
dtype=torch.int32,
device=self.device)
inputs_embeds = None
actual_num_reqs = min(num_tokens, self.max_num_reqs)
position_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros(num_tokens,
dtype=torch.int64,
device=self.device)
block_tables = torch.zeros((num_tokens, self.block_table_cpu.shape[1]),
dtype=torch.int32,
device=self.device)
query_lens = [1] * num_tokens
block_tables = torch.zeros(
(self.max_num_reqs, self.block_table_cpu.shape[1]),
dtype=torch.int32,
device=self.device)
query_lens = [1] * self.max_num_reqs
query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
dtype=torch.int32),
dim=0,
dtype=torch.int32).to(self.device)
context_lens = torch.ones((num_tokens, ),
context_lens = torch.ones((self.max_num_reqs, ),
dtype=torch.int32,
device=self.device)
num_seqs = torch.tensor([actual_num_reqs],
dtype=torch.int32,
device=self.device)
attn_metadata = PallasMetadata(
slot_mapping=slot_mapping,
block_tables=block_tables,
context_lens=context_lens,
query_start_loc=query_start_loc,
num_seqs=num_tokens,
num_seqs=num_seqs,
)
if self.is_multimodal_model:
@@ -724,9 +728,6 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0)
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
@@ -817,28 +818,6 @@ class ModelWrapperV1(nn.Module):
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
attn_metadata = get_forward_context().attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
# work, we need to flatten the first three dimensions and modify
# the slot_mapping accordingly.
# kv_caches: list[tuple[torch.Tensor, torch.Tensor]]
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
slot_mapping = attn_metadata.slot_mapping
slot_mapping = slot_mapping.flatten()
head_indicies = torch.arange(0,
num_kv_heads,
device=slot_mapping.device,
dtype=slot_mapping.dtype)
head_indicies *= block_size * num_blocks
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indicies.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
assert self.model is not None
hidden_states = self.model(
@@ -866,3 +845,9 @@ class ModelWrapperV1(nn.Module):
def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
def _get_padded_token_len(x: int) -> int:
if x <= 16:
return 16
return 1 << (x - 1).bit_length()