[TPU][V1] Remove ragged attention kernel parameter hard coding (#16041)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
@@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
AttentionLayer, AttentionType)
|
AttentionLayer, AttentionType)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
|
|
||||||
# These are the 2 tunable parameters of the paged attention Pallas kernel.
|
|
||||||
NUM_QUERIES_PER_BLOCK = 32
|
|
||||||
NUM_KV_PAGES_PER_BLOCK = 128
|
|
||||||
|
|
||||||
|
|
||||||
class PallasAttentionBackend(AttentionBackend):
|
class PallasAttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
@@ -115,13 +111,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
tpu_version = torch_xla.tpu.version()
|
tpu_version = torch_xla.tpu.version()
|
||||||
if tpu_version < 4:
|
if tpu_version < 4:
|
||||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||||
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
|
|
||||||
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
|
|
||||||
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
|
|
||||||
if tpu_version == 4:
|
|
||||||
self.vmem_limit_bytes = 16 * 1024 * 1024
|
|
||||||
else:
|
|
||||||
self.vmem_limit_bytes = 64 * 1024 * 1024
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -165,9 +154,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
attn_metadata.block_tables,
|
attn_metadata.block_tables,
|
||||||
attn_metadata.query_start_loc,
|
attn_metadata.query_start_loc,
|
||||||
attn_metadata.num_seqs,
|
attn_metadata.num_seqs,
|
||||||
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
|
# By default, the system utilizes optimized block size and
|
||||||
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
|
# vmem_limit_bytes parameters from the kernel repository. However,
|
||||||
vmem_limit_bytes=self.vmem_limit_bytes,
|
# these can be manually adjusted for debugging if necessary.
|
||||||
|
num_kv_pages_per_block=None,
|
||||||
|
num_queries_per_block=None,
|
||||||
|
vmem_limit_bytes=None,
|
||||||
use_kernel=True,
|
use_kernel=True,
|
||||||
sm_scale=self.scale,
|
sm_scale=self.scale,
|
||||||
sliding_window=self.sliding_window,
|
sliding_window=self.sliding_window,
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
|||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||||
PallasAttentionBackend,
|
|
||||||
PallasMetadata)
|
PallasMetadata)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
@@ -155,11 +154,8 @@ class TPUModelRunner:
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
|
|
||||||
padded_max_num_blocks_per_req = _get_padded_number(
|
|
||||||
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
|
|
||||||
self.block_table_cpu = torch.zeros(
|
self.block_table_cpu = torch.zeros(
|
||||||
(self.max_num_tokens, padded_max_num_blocks_per_req),
|
(self.max_num_tokens, self.max_num_blocks_per_req),
|
||||||
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
||||||
device="cpu")
|
device="cpu")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user