[TPU][V1] Fix Sampler recompilation (#15309)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -279,9 +279,6 @@ class TPUModelRunner:
|
||||
req_data.num_computed_tokens)
|
||||
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
||||
req_index)
|
||||
# Check if the batch has changed. If not, we can skip copying the
|
||||
# sampling metadata from CPU to GPU.
|
||||
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
@@ -300,9 +297,6 @@ class TPUModelRunner:
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
# TODO This slices tensors to copy to device, triggering recompilation.
|
||||
if batch_changed:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
@@ -597,14 +591,12 @@ class TPUModelRunner:
|
||||
# then the embedding layer is not included in the CUDA graph.
|
||||
input_ids = self.input_ids
|
||||
inputs_embeds = None
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
# NOTE (NickLucche) here we sync with TPU: if there's any shape
|
||||
# mismatch in pre-processing, it will trigger a small recompilation
|
||||
# of the code thus far. Forward graph remains untouched.
|
||||
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
|
||||
# are copied to device in chunks of pre-compiled padded shape to
|
||||
# avoid recompilations.
|
||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(sampling_metadata, logits_indices,
|
||||
num_reqs, self.device)
|
||||
from_input_batch(self.input_batch, logits_indices)
|
||||
# Run the decoder
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
hidden_states = self.model(
|
||||
@@ -797,21 +789,19 @@ class TPUModelRunner:
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
while True:
|
||||
# Default metadata is an all_greedy setup. But since the
|
||||
# `do_argmax` flag is a tensor, we still compile the full graph
|
||||
meta = self.input_batch.sampling_metadata
|
||||
indices = torch.zeros(
|
||||
num_reqs_to_sample,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
xm.mark_step()
|
||||
sampling_meta = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(meta, indices,
|
||||
num_reqs_to_sample, device)
|
||||
from_input_batch(self.input_batch, indices)
|
||||
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
||||
num_reqs_to_sample)
|
||||
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
|
||||
xm.mark_step()
|
||||
out = self.model.sample_from_hidden(dummy_hidden,
|
||||
sampling_meta)
|
||||
out = out.cpu()
|
||||
if num_reqs_to_sample >= self.max_num_reqs:
|
||||
break
|
||||
num_reqs_to_sample *= 2
|
||||
@@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
||||
def sample_from_hidden(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
|
||||
sample_hidden_states = \
|
||||
hidden_states[sampling_metadata.indices_do_sample]
|
||||
logits = self.compute_logits(sample_hidden_states)
|
||||
# Greedy sampling can't be run without branching the graph on Sampler.
|
||||
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
|
||||
# NOTE do_argmax is a scalar, this is just an optimized if/else.
|
||||
out_tokens = torch.where(sampling_metadata.do_argmax,
|
||||
# Optimized greedy sampling branch, tracing both paths in a single pass
|
||||
# NOTE all_greedy is a scalar, this is just an optimized if/else.
|
||||
out_tokens = torch.where(sampling_metadata.all_greedy,
|
||||
torch.argmax(logits, dim=-1, keepdim=True),
|
||||
self.sample(logits, sampling_metadata)\
|
||||
.sampled_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user