[TPU][V1] Add support for top-logprobs (#17072)
Signed-off-by: NickLucche <nlucches@redhat.com>
@@ -791,8 +791,18 @@ class TPUModelRunner:
                                                  arange)
         selected_token_ids = self.sample_from_logits(logits,
                                                      tpu_sampling_metadata)
 
+        # NOTE (NickLucche) Use the original logits (before any penalties or
+        # temperature scaling) for the top-k logprobs. We can't enforce it due
+        # to recompilations outside torch.compiled code, so just make sure
+        # `sample_from_logits` does not modify the logits in-place.
+        logprobs = self.gather_logprobs(logits, selected_token_ids) \
+            if tpu_sampling_metadata.logprobs else None
+
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
+        logprobs_lists = logprobs.tolists() \
+            if tpu_sampling_metadata.logprobs else None
+
         # Update the cache state concurrently. Code above will not block until
         # we use `selected_token_ids`. Add mark_step if post-processing changes
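The NOTE in this hunk hinges on `sample_from_logits` leaving `logits` untouched, since `gather_logprobs` later reads the very same tensor. A minimal sketch (not vLLM code, all values hypothetical) of the out-of-place vs. in-place distinction:

import torch

logits = torch.randn(4, 32000)   # hypothetical [num_reqs, vocab] batch
temperature = 0.7

# OK: out-of-place scaling leaves the original logits intact, so the
# top-k logprobs gathered afterwards still reflect the raw model output.
scaled = logits / temperature

# Not OK: an in-place transform would silently corrupt what
# gather_logprobs(logits, ...) reports.
# logits.div_(temperature)

assert scaled is not logits      # the original tensor is untouched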
@@ -862,7 +872,7 @@ class TPUModelRunner:
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=None,
-            logprobs=None,
+            logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
         )
 
@@ -1121,6 +1131,22 @@ class TPUModelRunner:
         logger.info("Compilation finished in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sample_from_logits")
 
+    def _precompile_gather_logprobs(self) -> None:
+        logger.info("Compiling gather_logprobs with different input shapes.")
+        start = time.perf_counter()
+        for num_reqs in self.num_reqs_paddings:
+            dummy_logits = torch.zeros((num_reqs, self.vocab_size),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            dummy_tokens = torch.zeros((num_reqs, 1),
+                                       dtype=torch.int64).to(self.device)
+            self.gather_logprobs(dummy_logits, dummy_tokens)
+            logger.info("  -- num_seqs: %d", num_reqs)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("gather_logprobs")
+
     def capture_model(self) -> None:
         """
         Precompile all the subgraphs with possible input shapes.
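The precompile loop exists because XLA compiles one graph per distinct input shape: live batches are padded up to the nearest bucket in `num_reqs_paddings`, and each bucket is traced once ahead of time. A sketch of that bucketing idea, with hypothetical bucket values:

import bisect

num_reqs_paddings = [8, 16, 32, 64]   # hypothetical bucket sizes

def pad_num_reqs(num_reqs: int) -> int:
    """Round a live batch size up to its precompiled bucket.
    Assumes num_reqs never exceeds the largest bucket."""
    idx = bisect.bisect_left(num_reqs_paddings, num_reqs)
    return num_reqs_paddings[idx]

assert pad_num_reqs(3) == 8
assert pad_num_reqs(16) == 16
assert pad_num_reqs(17) == 32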
@@ -1131,6 +1157,7 @@ class TPUModelRunner:
         self._precompile_compute_logits()
         self._precompile_structured_decoding()
         self._precompile_sample_from_logits()
+        self._precompile_gather_logprobs()
 
     def profile_run(
         self,
@@ -1254,6 +1281,10 @@ class TPUModelRunner:
     def sample_from_logits(
             self, logits: torch.Tensor,
             sampling_metadata: TPUSupportedSamplingMetadata) -> torch.Tensor:
+        """
+        Sample with xla-friendly function. This function is to be traced
+        separately from `forward` for lighter compilation overhead.
+        """
         if sampling_metadata.all_greedy:
             out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
         else:
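The new docstring points at the pattern used throughout this file: small helpers are traced separately instead of being folded into the big `forward` graph. A sketch of that pattern, assuming torch_xla is installed and using a hypothetical standalone helper (not the vLLM method itself):

import torch

# Compiling the sampling helper on its own keeps it out of the large model
# graph: a change in sampling behavior then only retraces this small function.
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    # keepdim=True preserves a [num_reqs, 1] shape for downstream gathers.
    return torch.argmax(logits, dim=-1, keepdim=True)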
@@ -1261,6 +1292,20 @@ class TPUModelRunner:
                 sampling_metadata).sampled_token_ids
         return out_tokens
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def gather_logprobs(self, logits: torch.Tensor,
+                        sampled_tokens: torch.Tensor) -> LogprobsTensors:
+        """
+        Gather the top_logprobs with corresponding tokens. Use a fixed number
+        of logprobs as an alternative to having multiple pre-compiled graphs.
+        Select the number of logprobs actually demanded by each request on CPU.
+        """
+        logprobs = self.sampler.compute_logprobs(logits)
+        return self.sampler.gather_logprobs(
+            logprobs,
+            self.model_config.max_logprobs,
+            token_ids=sampled_tokens.squeeze(-1))
+
     @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def structured_decode(self, require_struct_decoding: torch.Tensor,
                           grammar_bitmask: torch.Tensor, logits: torch.Tensor,
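Why the fixed width works: `gather_logprobs` always returns `max_logprobs` entries per request, so the compiled graph has a single static output shape; trimming each request down to the number of logprobs it actually asked for happens later on CPU. An approximate sketch of that flow in plain torch (not the vLLM sampler API; the per-request `requested` counts are hypothetical):

import torch

max_logprobs = 5                         # hypothetical server-wide cap
logits = torch.randn(2, 100)             # [num_reqs, vocab]
sampled = torch.tensor([3, 7])           # [num_reqs] sampled token ids

logprobs = torch.log_softmax(logits, dim=-1)
# Device side: fixed-shape top-k, one graph regardless of per-request asks.
top_vals, top_ids = logprobs.topk(max_logprobs, dim=-1)
# Logprob of the sampled token itself, shape [num_reqs, 1].
sampled_vals = logprobs.gather(-1, sampled.unsqueeze(-1))

# CPU side: each request only asked for k_i <= max_logprobs entries.
requested = [2, 5]                       # hypothetical per-request k
per_req = [
    (top_ids[i, :k].tolist(), top_vals[i, :k].tolist())
    for i, k in enumerate(requested)
]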