[V1] Remove num_input_tokens from attn_metadata (#17193)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-04-30 00:28:41 +08:00
committed by GitHub
parent 2ef5d106bb
commit 24e6ad3f16
6 changed files with 14 additions and 21 deletions

View File

@@ -1036,7 +1036,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
-attn_metadata.num_input_tokens = num_input_tokens
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
@@ -1088,7 +1087,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Run the decoder.
# Use persistent buffers for CUDA graphs.
-with set_forward_context(attn_metadata, self.vllm_config):
+with set_forward_context(attn_metadata,
+self.vllm_config,
+num_tokens=num_input_tokens):
output = self.model(
input_ids=input_ids,
positions=positions,

View File

@@ -769,7 +769,10 @@ class TPUModelRunner:
xm.mark_step()
num_reqs = self.input_batch.num_reqs
# Run the decoder
-with set_forward_context(attn_metadata, self.vllm_config):
+with set_forward_context(
+attn_metadata,
+self.vllm_config,
+num_tokens=scheduler_output.total_num_scheduled_tokens):
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,