[Misc][Refactor] Introduce ExecuteModelData (#4540)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -7,7 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async)
|
||||
|
||||
@@ -72,18 +72,10 @@ class CPUExecutor(ExecutorBase):
|
||||
logger.info("# CPU blocks: %d", num_gpu_blocks)
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int) -> List[SamplerOutput]:
|
||||
output = self.driver_worker.execute_model(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
output = self.driver_worker.execute_model(execute_model_req)
|
||||
return output
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
@@ -104,19 +96,10 @@ class CPUExecutor(ExecutorBase):
|
||||
class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_lookahead_slots: int,
|
||||
) -> List[SamplerOutput]:
|
||||
output = await make_async(self.driver_worker.execute_model)(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
num_lookahead_slots=num_lookahead_slots)
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
output = await make_async(self.driver_worker.execute_model
|
||||
)(execute_model_req=execute_model_req, )
|
||||
return output
|
||||
|
||||
async def check_health_async(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user