[ModelRunner V2][BugFix] Fix max_query_len calculation (#34167)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-09 13:47:17 -08:00
committed by GitHub
parent bb9f97308d
commit e7e52781ff
4 changed files with 6 additions and 1 deletions

View File

@@ -149,13 +149,13 @@ def build_attn_metadata(
num_tokens: int,
query_start_loc_gpu: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
max_query_len: int,
seq_lens: torch.Tensor,
max_seq_len: int,
block_tables: Sequence[torch.Tensor],
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
max_query_len = int(query_start_loc_cpu.max())
seq_lens = seq_lens[:num_reqs]
attn_metadata: dict[str, Any] = {}

View File

@@ -267,6 +267,7 @@ def prepare_inputs_to_capture(
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=num_tokens_per_req,
seq_lens=input_buffers.seq_lens,
max_seq_len=max_model_len,
block_tables=input_block_tables,

View File

@@ -274,6 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens=input_batch.num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
max_query_len=input_batch.num_scheduled_tokens.max().item(),
seq_lens=input_batch.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
@@ -561,6 +562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens.
prepare_prefill_inputs(
@@ -624,6 +626,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,

View File

@@ -301,6 +301,7 @@ class EagleSpeculator:
num_tokens=num_reqs,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
max_seq_len=self.max_model_len,
block_tables=block_tables,