[TPU][V1] Remove ragged attention kernel parameter hard coding (#16041)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
Chengji Yao
2025-04-04 04:48:50 -07:00
committed by GitHub
parent 86cbd2eee9
commit fadc59c0e6
2 changed files with 8 additions and 20 deletions

View File

@@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK = 32
NUM_KV_PAGES_PER_BLOCK = 128
class PallasAttentionBackend(AttentionBackend):
@@ -115,13 +111,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
tpu_version = torch_xla.tpu.version()
if tpu_version < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
if tpu_version == 4:
self.vmem_limit_bytes = 16 * 1024 * 1024
else:
self.vmem_limit_bytes = 64 * 1024 * 1024
def forward(
self,
@@ -165,9 +154,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
vmem_limit_bytes=self.vmem_limit_bytes,
# By default, the system utilizes optimized block size and
# vmem_limit_bytes parameters from the kernel repository. However,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block=None,
num_queries_per_block=None,
vmem_limit_bytes=None,
use_kernel=True,
sm_scale=self.scale,
sliding_window=self.sliding_window,