[core] [3/N] multi-step args and sequence.py (#7452)

This commit is contained in:
William Lin
2024-08-14 12:32:45 -07:00
committed by GitHub
parent 3f674a49b5
commit 2ecf7b1757
4 changed files with 100 additions and 5 deletions

View File

@@ -847,7 +847,8 @@ class SchedulerConfig:
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None) -> None:
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
@@ -876,6 +877,7 @@ class SchedulerConfig:
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self._verify_args()
def _verify_args(self) -> None:
@@ -901,6 +903,16 @@ class SchedulerConfig:
f"({self.num_lookahead_slots}) must be greater than or "
"equal to 0.")
if self.num_scheduler_steps < 1:
raise ValueError(
"num_scheduler_steps "
f"({self.num_scheduler_steps}) must be greater than or "
"equal to 1.")
@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
class DeviceConfig:
device: Optional[torch.device]