[V1][TPU] Support V1 Sampler for ragged attention (#14227)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -23,13 +23,16 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||
PallasAttentionBackend,
|
||||
PallasMetadata)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput)
|
||||
ModelRunnerOutput, SamplerOutput)
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
@@ -42,6 +45,8 @@ logger = init_logger(__name__)
|
||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||
_PAD_SLOT_ID = 1_000_000_000
|
||||
INVALID_TOKEN_ID = -1
|
||||
# Smallest output size
|
||||
MIN_NUM_SEQS = 8
|
||||
|
||||
|
||||
class TPUModelRunner:
|
||||
@@ -138,8 +143,10 @@ class TPUModelRunner:
|
||||
device="cpu")
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
|
||||
padded_max_num_blocks_per_req = _get_padded_number(
|
||||
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(self.max_num_tokens, self.max_num_blocks_per_req),
|
||||
(self.max_num_tokens, padded_max_num_blocks_per_req),
|
||||
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
||||
device="cpu")
|
||||
|
||||
@@ -267,6 +274,9 @@ class TPUModelRunner:
|
||||
req_data.num_computed_tokens)
|
||||
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
||||
req_index)
|
||||
# Check if the batch has changed. If not, we can skip copying the
|
||||
# sampling metadata from CPU to GPU.
|
||||
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
@@ -284,6 +294,10 @@ class TPUModelRunner:
|
||||
# Condense the batched states if there are empty indices.
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
# TODO This slices tensors to copy to device, triggering recompilation.
|
||||
if batch_changed:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
@@ -447,6 +461,8 @@ class TPUModelRunner:
|
||||
# TODO: Support prompt logprobs.
|
||||
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
|
||||
num_reqs, self.max_num_reqs)
|
||||
# Indices at which we sample (positions of last token in the sequence).
|
||||
# Padded to avoid recompiling when `num_reqs` varies.
|
||||
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
|
||||
logits_indices = logits_indices.to(self.device)
|
||||
return attn_metadata, logits_indices
|
||||
@@ -576,7 +592,14 @@ class TPUModelRunner:
|
||||
# then the embedding layer is not included in the CUDA graph.
|
||||
input_ids = self.input_ids
|
||||
inputs_embeds = None
|
||||
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
# NOTE (NickLucche) here we sync with TPU: if there's any shape
|
||||
# mismatch in pre-processing, it will trigger a small recompilation
|
||||
# of the code thus far. Forward graph remains untouched.
|
||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(sampling_metadata, logits_indices,
|
||||
num_reqs, self.device)
|
||||
# Run the decoder
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
hidden_states = self.model(
|
||||
@@ -585,12 +608,13 @@ class TPUModelRunner:
|
||||
kv_caches=self.kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
selected_token_ids = self.model.compute_logits(hidden_states,
|
||||
logits_indices, None)
|
||||
selected_token_ids = self.model.sample_from_hidden(
|
||||
hidden_states, tpu_sampling_metadata)
|
||||
# Remove padding on cpu and keep dynamic op outside of xla graph.
|
||||
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
|
||||
|
||||
# Then, let's update the cache state.
|
||||
# Update the cache state concurrently. Code above will not block until
|
||||
# we use `selected_token_ids`. Add mark_step if post-processing changes
|
||||
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
|
||||
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
||||
assert req_id is not None
|
||||
@@ -607,7 +631,6 @@ class TPUModelRunner:
|
||||
# This relies on cuda-specific torch-internal impl details
|
||||
generator.set_offset(generator.get_offset() - 4)
|
||||
|
||||
# num_reqs entries should be non-None
|
||||
assert all(
|
||||
req_id is not None for req_id in
|
||||
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||
@@ -620,6 +643,7 @@ class TPUModelRunner:
|
||||
max_gen_len = selected_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
valid_sampled_token_ids = selected_token_ids.tolist()
|
||||
|
||||
for i, req_state, seq_len in request_seq_lens:
|
||||
token_id = valid_sampled_token_ids[i][0]
|
||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
||||
@@ -676,11 +700,8 @@ class TPUModelRunner:
|
||||
fullgraph=True,
|
||||
dynamic=False)
|
||||
|
||||
def _dummy_run(
|
||||
self,
|
||||
kv_caches,
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
@torch.no_grad()
|
||||
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
|
||||
if self.is_multimodal_model:
|
||||
input_ids = None
|
||||
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
|
||||
@@ -729,32 +750,10 @@ class TPUModelRunner:
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
num_reqs = _get_padded_num_reqs_with_upper_limit(
|
||||
64, self.max_num_reqs)
|
||||
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
|
||||
# compilation cache size of token_bucket_num multiplied by
|
||||
# req_bucket_num. This is acceptable, given the graph's relatively
|
||||
# small size.
|
||||
while True:
|
||||
logits_indices = torch.zeros(
|
||||
num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
torch._dynamo.mark_dynamic(logits_indices, 0)
|
||||
self.model.compute_logits(hidden_states, logits_indices, None)
|
||||
if num_reqs >= self.max_num_reqs:
|
||||
break
|
||||
num_reqs = _get_padded_num_reqs_with_upper_limit(
|
||||
num_reqs + 1, self.max_num_reqs)
|
||||
self.model(input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
@@ -764,13 +763,51 @@ class TPUModelRunner:
|
||||
start = time.perf_counter()
|
||||
num_tokens = 16
|
||||
while True:
|
||||
self._dummy_run(self.kv_caches, num_tokens)
|
||||
logger.info(" -- num_tokens: %d", num_tokens)
|
||||
self._dummy_run(self.kv_caches, num_tokens)
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
if num_tokens >= self.max_num_tokens:
|
||||
break
|
||||
num_tokens *= 2
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
|
||||
logger.info("Compiling sampling with different input shapes.")
|
||||
start = time.perf_counter()
|
||||
num_tokens = 16
|
||||
hsize = self.model_config.get_hidden_size()
|
||||
device = self.device
|
||||
# Compile sampling step for different model+sampler outputs in bucketed
|
||||
# n_tokens x max_num_reqs. Graph is really small so this is fine.
|
||||
while True:
|
||||
num_reqs_to_sample = MIN_NUM_SEQS
|
||||
dummy_hidden = torch.randn((num_tokens, hsize),
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
while True:
|
||||
# Default metadata is an all_greedy setup. But since the
|
||||
# `do_argmax` flag is a tensor, we still compile the full graph
|
||||
meta = self.input_batch.sampling_metadata
|
||||
indices = torch.zeros(
|
||||
num_reqs_to_sample,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
sampling_meta = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(meta, indices,
|
||||
num_reqs_to_sample, device)
|
||||
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
||||
num_reqs_to_sample)
|
||||
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
|
||||
xm.mark_step()
|
||||
if num_reqs_to_sample >= self.max_num_reqs:
|
||||
break
|
||||
num_reqs_to_sample *= 2
|
||||
if num_tokens >= self.max_num_tokens:
|
||||
break
|
||||
num_tokens *= 2
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
|
||||
@@ -818,6 +855,13 @@ class ModelWrapperV1(nn.Module):
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.sampler = TPUSampler()
|
||||
|
||||
def sample(
|
||||
self, logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
|
||||
sampler_out = self.sampler(logits, sampling_metadata)
|
||||
return sampler_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -826,7 +870,7 @@ class ModelWrapperV1(nn.Module):
|
||||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
"""Executes the forward pass of the model.
|
||||
|
||||
Args:
|
||||
input_ids: The input token IDs of shape [num_tokens].
|
||||
@@ -837,7 +881,6 @@ class ModelWrapperV1(nn.Module):
|
||||
hidden_size]. It is used for multimodal models.
|
||||
"""
|
||||
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@@ -846,17 +889,33 @@ class ModelWrapperV1(nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
||||
def compute_logits(
|
||||
def sample_from_hidden(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
logits_indices: torch.Tensor,
|
||||
sampling_metadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
return selected_token_ids
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Sample with xla-friendly function. This function is to be traced
|
||||
separately from `forward` for lighter compilation overhead.
|
||||
"""
|
||||
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
|
||||
sample_hidden_states = \
|
||||
hidden_states[sampling_metadata.indices_do_sample]
|
||||
logits = self.compute_logits(sample_hidden_states)
|
||||
# Greedy sampling can't be run without branching the graph on Sampler.
|
||||
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
|
||||
# NOTE do_argmax is a scalar, this is just an optimized if/else.
|
||||
out_tokens = torch.where(sampling_metadata.do_argmax,
|
||||
torch.argmax(logits, dim=-1, keepdim=True),
|
||||
self.sample(logits, sampling_metadata)\
|
||||
.sampled_token_ids)
|
||||
return out_tokens
|
||||
|
||||
def compute_logits(self,
|
||||
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
return logits
|
||||
|
||||
def get_multimodal_embeddings(self, *args, **kwargs):
|
||||
return self.model.get_multimodal_embeddings(*args, **kwargs)
|
||||
@@ -876,5 +935,5 @@ def _get_padded_token_len(x: int) -> int:
|
||||
|
||||
|
||||
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
|
||||
res = 64 if x <= 64 else 1 << (x - 1).bit_length()
|
||||
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
|
||||
return min(res, upper_limit)
|
||||
|
||||
Reference in New Issue
Block a user