Add seq_ids to input metadata
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -120,6 +120,7 @@ class Worker:
|
||||
padded_block_tables, dtype=int, device=self.device)
|
||||
|
||||
input_metadata = InputMetadata(
|
||||
seq_ids=prompt_seq_ids + generation_seq_ids,
|
||||
prompt_lens=prompt_lens,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
context_lens=context_lens_tensor,
|
||||
@@ -128,7 +129,6 @@ class Worker:
|
||||
)
|
||||
return tokens_tensor, positions_tensor, input_metadata
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_stage(
|
||||
self,
|
||||
@@ -139,7 +139,7 @@ class Worker:
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, int],
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
|
||||
# Issue cache operations.
|
||||
command_issued = False
|
||||
if blocks_to_swap_in:
|
||||
|
||||
Reference in New Issue
Block a user