[V1][TPU] Support V1 Sampler for ragged attention (#14227)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2025-03-20 05:00:39 +01:00
Committed by: GitHub
parent 40828ce5fe
commit d8c6d7d6b5
6 changed files with 535 additions and 55 deletions


@@ -23,13 +23,16 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -42,6 +45,8 @@ logger = init_logger(__name__)
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1
# Smallest output size
MIN_NUM_SEQS = 8
class TPUModelRunner:
@@ -138,8 +143,10 @@ class TPUModelRunner:
device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros(
(self.max_num_tokens, self.max_num_blocks_per_req),
(self.max_num_tokens, padded_max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu")
@@ -267,6 +274,9 @@ class TPUModelRunner:
req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to device.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@@ -284,6 +294,10 @@ class TPUModelRunner:
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module:
@@ -447,6 +461,8 @@ class TPUModelRunner:
# TODO: Support prompt logprobs.
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs, self.max_num_reqs)
# Indices at which we sample (positions of last token in the sequence).
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices
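# --- Worked example, not part of this commit ----------------------------------
# `query_start_loc` holds cumulative token counts of the ragged batch, so
# subtracting 1 from its tail gives the position of each request's last token,
# i.e. exactly where logits must be gathered for sampling. Values below are
# illustrative only.
query_start_loc_demo = torch.tensor([0, 4, 6, 11], dtype=torch.int32)  # 3 reqs of 4, 2, 5 tokens
logits_indices_demo = query_start_loc_demo[1:] - 1                     # tensor([ 3,  5, 10])
# ------------------------------------------------------------------------------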
@@ -576,7 +592,14 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids
inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# mismatch in pre-processing, it will trigger a small recompilation
# of the code thus far. Forward graph remains untouched.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices,
num_reqs, self.device)
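# --- Illustrative sketch, not part of this commit -----------------------------
# Conceptually, `from_sampling_metadata` pads the per-request sampling
# parameters (temperature, top-p, ...) out to the pre-compiled batch size and
# copies them to the TPU, so the traced sampling graph always sees fixed
# shapes. The helper name and values below are assumptions for illustration.
def _pad_param_sketch(vals: list[float], padded_len: int, fill: float) -> torch.Tensor:
    return torch.tensor(vals + [fill] * (padded_len - len(vals)), dtype=torch.float32)
# e.g. temperatures of 3 live requests padded to 8 slots before the device copy:
# _pad_param_sketch([0.7, 1.0, 0.0], 8, fill=1.0) -> tensor of shape (8,)
# -------------------------------------------------------------------------------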
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
@@ -585,12 +608,13 @@ class TPUModelRunner:
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = self.input_batch.num_reqs
selected_token_ids = self.model.compute_logits(hidden_states,
logits_indices, None)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
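# --- Illustrative note, not part of this commit --------------------------------
# The sampler emits one token per *padded* request slot; moving the tensor to
# CPU before slicing keeps the shape-dependent `[:num_reqs]` op out of the
# traced XLA graph. Shapes below are illustrative only:
#   device output: [padded_num_reqs, 1] = [8, 1]
#   selected_token_ids.cpu()[:3]       -> [3, 1] when only 3 requests are live
# --------------------------------------------------------------------------------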
# Then, let's update the cache state.
# Update the cache state concurrently. The code above will not block until
# we actually use `selected_token_ids`; add a mark_step here if post-processing changes.
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
@@ -607,7 +631,6 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
@@ -620,6 +643,7 @@ class TPUModelRunner:
max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = selected_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
@@ -676,11 +700,8 @@ class TPUModelRunner:
fullgraph=True,
dynamic=False)
def _dummy_run(
self,
kv_caches,
num_tokens: int,
) -> None:
@torch.no_grad()
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -729,32 +750,10 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = _get_padded_num_reqs_with_upper_limit(
64, self.max_num_reqs)
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
# compilation cache size of token_bucket_num multiplied by
# req_bucket_num. This is acceptable, given the graph's relatively
# small size.
while True:
logits_indices = torch.zeros(
num_reqs,
dtype=torch.int32,
device=self.device,
)
torch._dynamo.mark_dynamic(hidden_states, 0)
torch._dynamo.mark_dynamic(logits_indices, 0)
self.model.compute_logits(hidden_states, logits_indices, None)
if num_reqs >= self.max_num_reqs:
break
num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs + 1, self.max_num_reqs)
self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
def capture_model(self) -> None:
"""Compile the model."""
@@ -764,13 +763,51 @@ class TPUModelRunner:
start = time.perf_counter()
num_tokens = 16
while True:
self._dummy_run(self.kv_caches, num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step()
xm.wait_device_ops()
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
num_tokens = 16
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile the sampling step over bucketed (num_tokens x num_reqs) output
# shapes. The graph is small, so the extra compilations are cheap.
while True:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dtype=torch.bfloat16)
while True:
# The default metadata is an all_greedy setup, but since the
# `do_argmax` flag is a tensor we still compile the full graph.
meta = self.input_batch.sampling_metadata
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices,
num_reqs_to_sample, device)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
xm.mark_step()
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -818,6 +855,13 @@ class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.sampler = TPUSampler()
def sample(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
def forward(
self,
@@ -826,7 +870,7 @@ class ModelWrapperV1(nn.Module):
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
@@ -837,7 +881,6 @@ class ModelWrapperV1(nn.Module):
hidden_size]. It is used for multimodal models.
"""
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
@@ -846,17 +889,33 @@ class ModelWrapperV1(nn.Module):
return hidden_states
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def compute_logits(
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
logits_indices: torch.Tensor,
sampling_metadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, sampling_metadata)
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
return selected_token_ids
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample next tokens in an XLA-friendly way. This function is traced
separately from `forward` to keep compilation overhead low.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't be handled inside the Sampler without branching
# the graph, so the do_argmax/all_greedy flag is checked here in an
# XLA-friendly way. NOTE: do_argmax is a scalar tensor, so this is just
# an optimized if/else.
out_tokens = torch.where(sampling_metadata.do_argmax,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
return out_tokens
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# SamplingMetadata would only be used by the LogitsProcessor to prune its
# output; pass None to disable that here.
logits = self.model.compute_logits(hidden_states, None)
return logits
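# --- Standalone illustration, not part of this commit ----------------------------
# The `do_argmax` trick used in `sample_from_hidden` above: both the greedy and
# the sampled branch are evaluated inside the traced graph, and a scalar boolean
# *tensor* selects the result, so no Python-level branching (and hence no graph
# split) is needed. Values below are illustrative only.
logits_demo = torch.randn(4, 32000)                    # [padded_num_reqs, vocab_size]
do_argmax_demo = torch.tensor(True)                    # scalar flag living on device
greedy_demo = torch.argmax(logits_demo, dim=-1, keepdim=True)
sampled_demo = torch.multinomial(torch.softmax(logits_demo, dim=-1), num_samples=1)
out_tokens_demo = torch.where(do_argmax_demo, greedy_demo, sampled_demo)  # [4, 1]
# ----------------------------------------------------------------------------------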
def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
@@ -876,5 +935,5 @@ def _get_padded_token_len(x: int) -> int:
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
res = 64 if x <= 64 else 1 << (x - 1).bit_length()
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit)
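# --- Worked example, not part of this commit --------------------------------------
# Given MIN_NUM_SEQS = 8, `_get_padded_num_reqs_with_upper_limit` floors at 8,
# otherwise rounds up to the next power of two, and caps at `upper_limit`:
#   _get_padded_num_reqs_with_upper_limit(3,  64) -> 8
#   _get_padded_num_reqs_with_upper_limit(9,  64) -> 16
#   _get_padded_num_reqs_with_upper_limit(33, 64) -> 64
#   _get_padded_num_reqs_with_upper_limit(70, 64) -> 64  (capped)
# -----------------------------------------------------------------------------------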