[TPU][V1] Add support for top-logprobs (#17072)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author:    Nicolò Lucchesi <nlucches@redhat.com>
Date:      2025-05-05 23:20:15 +02:00 (committed by GitHub)
Parent:    9765940824
Commit:    5941e0b7ea

4 changed files with 105 additions and 17 deletions


@@ -791,8 +791,18 @@ class TPUModelRunner:
                                         arange)
         selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
+        # NOTE (NickLucche) Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs. We can't enforce it due
+        # to recompilations outside torch.compiled code, so just make sure
+        # `sample_from_logits` does not modify the logits in-place.
+        logprobs = self.gather_logprobs(logits, selected_token_ids) \
+            if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        logprobs_lists = logprobs.tolists() \
+            if tpu_sampling_metadata.logprobs else None
+
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
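
A note on the NOTE above: the top-k logprobs are deliberately computed from the raw logits, before any penalties or temperature scaling. A minimal sketch of the invariant (illustrative shapes only, not the runner's actual tensors):

    import torch

    logits = torch.randn(4, 32)                 # [num_reqs, vocab_size]
    selected = torch.argmax(logits, dim=-1, keepdim=True)

    # Logprobs must come from the same, unmodified logits tensor.
    logprobs = torch.log_softmax(logits, dim=-1)
    topk_vals, topk_ids = logprobs.topk(5, dim=-1)
    sampled_logprob = logprobs.gather(-1, selected)

    # An in-place op inside sampling, e.g. logits.div_(temperature),
    # would silently skew every value gathered above.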
@@ -862,7 +872,7 @@ class TPUModelRunner:
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=None,
-            logprobs=None,
+            logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
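
The `logprobs_lists` handed to `ModelRunnerOutput` is the CPU-side, list-of-lists view of the gathered tensors. A rough sketch of that handoff (field names assumed, loosely modeled on vLLM v1's `LogprobsTensors`/`LogprobsLists` pair, not the exact definitions):

    import torch
    from typing import NamedTuple

    class LogprobsTensorsSketch(NamedTuple):
        # Fixed-shape device tensors, safe inside the XLA graph.
        logprob_token_ids: torch.Tensor    # [num_reqs, k] int
        logprobs: torch.Tensor             # [num_reqs, k] float
        selected_token_ranks: torch.Tensor # [num_reqs] int

        def tolists(self):
            # Runs on CPU; the variable-length view never enters the graph.
            return (self.logprob_token_ids.tolist(),
                    self.logprobs.tolist(),
                    self.selected_token_ranks.tolist())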
@@ -1121,6 +1131,22 @@ class TPUModelRunner:
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info(" -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
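
`_precompile_gather_logprobs` warms up one graph per batch-size bucket because XLA retraces a compiled function for every new input shape. The bucketing idea in toy form (bucket values are an example, not the runner's actual `num_reqs_paddings`):

    # Toy version of shape bucketing: round each batch up to a fixed
    # bucket so only a handful of graphs ever need compiling.
    NUM_REQS_PADDINGS = [8, 16, 32, 64]

    def pad_num_reqs(num_reqs: int) -> int:
        """Round a batch up to the nearest precompiled bucket."""
        for bucket in NUM_REQS_PADDINGS:
            if num_reqs <= bucket:
                return bucket
        return NUM_REQS_PADDINGS[-1]

    assert pad_num_reqs(3) == 8    # served by the 8-request graph
    assert pad_num_reqs(20) == 32  # served by the 32-request graph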
@@ -1131,6 +1157,7 @@ class TPUModelRunner:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1254,6 +1281,10 @@
     def sample_from_logits(
             self, logits: torch.Tensor,
             sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        """
+        Sample with xla-friendly function. This function is to be traced
+        separately from `forward` for lighter compilation overhead.
+        """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
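
The docstring above points at a pattern this runner uses throughout: small post-processing steps are compiled as their own XLA graphs rather than folded into `forward`, so a shape change retraces only one small stage. A standalone illustration (toy function, not the runner's; running it requires torch_xla):

    import torch

    # Compiling the sampling step as its own graph keeps recompilation
    # local: a new batch shape retraces this function, not the model.
    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
        return torch.argmax(logits, dim=-1, keepdim=True)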
@@ -1261,6 +1292,20 @@
                 sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top_logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
                           grammar_bitmask: torch.Tensor, logits: torch.Tensor,
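
As the `gather_logprobs` docstring says, the compiled graph always gathers `max_logprobs` entries so a single graph serves every request; trimming to each request's actual logprobs count happens afterwards on CPU. Roughly, under assumed shapes (not the runner's code):

    import torch

    max_logprobs = 20
    logprobs = torch.log_softmax(torch.randn(2, 128), dim=-1)  # [num_reqs, vocab]

    # Fixed-size top-k: the output shape is identical for every request,
    # which is what lets one XLA graph cover all logprobs settings.
    topk_vals, topk_ids = logprobs.topk(max_logprobs, dim=-1)

    # Per-request trim on CPU, where dynamic slicing costs no recompilation.
    requested = [5, 2]  # e.g. request 0 asked for 5 logprobs, request 1 for 2
    trimmed = [(topk_ids[i, :k].tolist(), topk_vals[i, :k].tolist())
               for i, k in enumerate(requested)]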