Add seq_ids to input metadata

This commit is contained in:
Woosuk Kwon
2023-02-23 09:25:01 +00:00
parent 4f6f4967f6
commit 343cea3dbc
2 changed files with 6 additions and 3 deletions

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
import torch
@@ -120,6 +120,7 @@ class Worker:
padded_block_tables, dtype=int, device=self.device)
input_metadata = InputMetadata(
seq_ids=prompt_seq_ids + generation_seq_ids,
prompt_lens=prompt_lens,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
@@ -128,7 +129,6 @@ class Worker:
)
return tokens_tensor, positions_tensor, input_metadata
@torch.inference_mode()
def execute_stage(
self,
@@ -139,7 +139,7 @@ class Worker:
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, int],
) -> torch.Tensor:
) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
# Issue cache operations.
command_issued = False
if blocks_to_swap_in: