[ModelRunner V2][BugFix] Fix max_query_len calculation (#34167)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -149,13 +149,13 @@ def build_attn_metadata(
|
||||
num_tokens: int,
|
||||
query_start_loc_gpu: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
max_query_len: int,
|
||||
seq_lens: torch.Tensor,
|
||||
max_seq_len: int,
|
||||
block_tables: Sequence[torch.Tensor],
|
||||
slot_mappings: torch.Tensor,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
max_query_len = int(query_start_loc_cpu.max())
|
||||
seq_lens = seq_lens[:num_reqs]
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
|
||||
@@ -267,6 +267,7 @@ def prepare_inputs_to_capture(
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=num_tokens_per_req,
|
||||
seq_lens=input_buffers.seq_lens,
|
||||
max_seq_len=max_model_len,
|
||||
block_tables=input_block_tables,
|
||||
|
||||
@@ -274,6 +274,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens=input_batch.num_tokens,
|
||||
query_start_loc_gpu=input_batch.query_start_loc,
|
||||
query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
|
||||
max_query_len=input_batch.num_scheduled_tokens.max().item(),
|
||||
seq_lens=input_batch.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
@@ -561,6 +562,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
|
||||
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
max_query_len = num_scheduled_tokens.max().item()
|
||||
|
||||
# Get prefill tokens.
|
||||
prepare_prefill_inputs(
|
||||
@@ -624,6 +626,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=max_query_len,
|
||||
seq_lens=self.input_buffers.seq_lens,
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
|
||||
@@ -301,6 +301,7 @@ class EagleSpeculator:
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
|
||||
Reference in New Issue
Block a user