[TPU][V1] Fix Sampler recompilation (#15309)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-03-25 21:43:54 +01:00
committed by GitHub
parent e977c11111
commit a0dd7dcd49
2 changed files with 83 additions and 126 deletions

View File

@@ -279,9 +279,6 @@ class TPUModelRunner:
req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@@ -300,9 +297,6 @@ class TPUModelRunner:
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module:
@@ -597,14 +591,12 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids
inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# mismatch in pre-processing, it will trigger a small recompilation
# of the code thus far. Forward graph remains untouched.
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
# are copied to device in chunks of pre-compiled padded shape to
# avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices,
num_reqs, self.device)
from_input_batch(self.input_batch, logits_indices)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
@@ -797,21 +789,19 @@ class TPUModelRunner:
device=device,
dtype=torch.bfloat16)
while True:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta = self.input_batch.sampling_metadata
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
xm.mark_step()
sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices,
num_reqs_to_sample, device)
from_input_batch(self.input_batch, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
xm.mark_step()
out = self.model.sample_from_hidden(dummy_hidden,
sampling_meta)
out = out.cpu()
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
@@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
return hidden_states
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
@@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't be run without branching the graph on Sampler.
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
# NOTE do_argmax is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.do_argmax,
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)