[Misc][Refactor] Introduce ExecuteModelData (#4540)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass, fields
|
||||
from itertools import count
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
@dataclass
class ExecuteModelData:
    """Bundle of the arguments passed to a worker's execute_model call.

    Helper data structure which facilitates cleaner tests by grouping the
    sequence-group metadata together with the block-management dictionaries.
    """
    seq_group_metadata_list: List[SequenceGroupMetadata]
    blocks_to_swap_in: Dict[int, int]
    blocks_to_swap_out: Dict[int, int]
    blocks_to_copy: Dict[int, List[int]]

    def to_dict(self):
        """Shallow-convert this instance into a field-name -> value dict.

        Deliberately shallow (unlike dataclasses.asdict, which would
        recursively copy the contained dicts).
        """
        return {f.name: getattr(self, f.name) for f in fields(self)}

    @classmethod
    def from_dict(cls, d):
        """Build an instance from a dict, ignoring any extra keys in d."""
        kwargs = {f.name: d[f.name] for f in fields(cls)}
        return cls(**kwargs)
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return the number of blocks needed to hold ``seq_len`` tokens.

    NOTE(review): despite the name, this is ceiling division (a block
    *count*), not ``seq_len`` rounded up to a multiple of ``block_size``.
    """
    full_blocks, remainder = divmod(seq_len, block_size)
    return full_blocks + (1 if remainder else 0)
|
||||
|
||||
|
||||
def create_execute_model_data(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    blocks_to_swap_in: Optional[Dict[int, int]] = None,
    blocks_to_swap_out: Optional[Dict[int, int]] = None,
    # Fixed annotation: was Dict[int, int], inconsistent with the
    # ExecuteModelData.blocks_to_copy field, which maps a source block
    # to a *list* of destination blocks.
    blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> ExecuteModelData:
    """Construct an ExecuteModelData, defaulting block dicts to empty.

    Uses None sentinels rather than mutable ``{}`` defaults so a fresh
    dict is created per call and never shared across invocations.

    Args:
        seq_group_metadata_list: Sequence-group metadata for the batch.
        blocks_to_swap_in: CPU->GPU block swap map; empty if None.
        blocks_to_swap_out: GPU->CPU block swap map; empty if None.
        blocks_to_copy: Source block -> destination blocks copy map;
            empty if None.

    Returns:
        A populated ExecuteModelData instance.
    """
    if blocks_to_swap_in is None:
        blocks_to_swap_in = {}
    if blocks_to_swap_out is None:
        blocks_to_swap_out = {}
    if blocks_to_copy is None:
        blocks_to_copy = {}

    return ExecuteModelData(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in=blocks_to_swap_in,
        blocks_to_swap_out=blocks_to_swap_out,
        blocks_to_copy=blocks_to_copy,
    )
|
||||
|
||||
|
||||
def mock_worker(cls=None,
|
||||
vocab_size: int = 30_000,
|
||||
max_model_len: int = 2048,
|
||||
@@ -258,8 +217,7 @@ def create_batch(batch_size,
|
||||
for prompt, prev_output_token in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
execute_model_data = create_execute_model_data(
|
||||
create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
|
||||
block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids), )
|
||||
return execute_model_data, prompts, prev_output_tokens
|
||||
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size, final_prompt_lens,
|
||||
prev_output_tokens, seq_ids)
|
||||
return seq_group_metadata_list, prompts, prev_output_tokens
|
||||
|
||||
Reference in New Issue
Block a user