[TPU] Support sliding window and logit soft capping in the paged attention kernel for TPU. (#15732)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Author: iefgnoix
Date: 2025-04-03 14:23:28 -07:00
Committed by: GitHub
Parent: 03a70eacaf
Commit: b6be6f8d1e
4 changed files with 128 additions and 18 deletions

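For reference, the two new kernel options correspond to standard transforms on the pre-softmax attention scores: logit soft capping squashes the scaled scores through a tanh, scores = soft_cap * tanh(scores / soft_cap), and a sliding window masks out keys that fall more than sliding_window positions behind the query. Below is a minimal non-Pallas sketch of both transforms in plain PyTorch; the function name, tensor layout, and exact window boundary are illustrative assumptions, not the ragged paged attention kernel's actual interface.

import torch

def reference_attention(q, k, v, sm_scale, sliding_window=None, soft_cap=None):
    # q: [q_len, num_heads, head_size]; k, v: [k_len, num_heads, head_size].
    scores = torch.einsum("qhd,khd->hqk", q, k) * sm_scale
    if soft_cap is not None:
        # Logit soft capping: smoothly bound scores to (-soft_cap, soft_cap).
        scores = soft_cap * torch.tanh(scores / soft_cap)
    q_len, k_len = q.shape[0], k.shape[0]
    # Absolute position of each query, assuming the query block ends the sequence.
    q_pos = torch.arange(q_len).unsqueeze(1) + (k_len - q_len)
    k_pos = torch.arange(k_len).unsqueeze(0)
    mask = k_pos > q_pos  # causal mask
    if sliding_window is not None:
        # Keep only the last `sliding_window` keys for each query.
        mask |= k_pos <= q_pos - sliding_window
    scores = scores.masked_fill(mask.unsqueeze(0), float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", weights, v)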

@@ -92,6 +92,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.sliding_window = sliding_window
+        self.logits_soft_cap = logits_soft_cap
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -99,15 +101,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError("Head size must be a multiple of 128.")
if alibi_slopes is not None:
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
if logits_soft_cap is not None:
raise NotImplementedError(
"Attention logits soft-capping is not supported.")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
@@ -172,7 +169,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
             vmem_limit_bytes=self.vmem_limit_bytes,
             use_kernel=True,
-            sm_scale=self.scale)
+            sm_scale=self.scale,
+            sliding_window=self.sliding_window,
+            soft_cap=self.logits_soft_cap,
+        )
         return output.reshape(num_tokens, hidden_size)
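With both options threaded through to the kernel call, passing None leaves the behavior unchanged, so models without a sliding window or soft cap are unaffected. As a quick standalone sanity check of the capping transform (the 30.0 is an illustrative value, not a kernel default), large raw scores saturate at plus or minus soft_cap:

import torch

soft_cap = 30.0
raw = torch.tensor([-200.0, -30.0, 0.0, 30.0, 200.0])
capped = soft_cap * torch.tanh(raw / soft_cap)
print(capped)  # ~[-30.0, -22.8, 0.0, 22.8, 30.0]: bounded to (-30, 30)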