[core] [3/N] multi-step args and sequence.py (#7452)
This commit is contained in:
@@ -847,7 +847,8 @@ class SchedulerConfig:
|
||||
delay_factor: float = 0.0,
|
||||
enable_chunked_prefill: bool = False,
|
||||
embedding_mode: Optional[bool] = False,
|
||||
preemption_mode: Optional[str] = None) -> None:
|
||||
preemption_mode: Optional[str] = None,
|
||||
num_scheduler_steps: int = 1) -> None:
|
||||
if max_num_batched_tokens is not None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
else:
|
||||
@@ -876,6 +877,7 @@ class SchedulerConfig:
|
||||
self.chunked_prefill_enabled = enable_chunked_prefill
|
||||
self.embedding_mode = embedding_mode
|
||||
self.preemption_mode = preemption_mode
|
||||
self.num_scheduler_steps = num_scheduler_steps
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
@@ -901,6 +903,16 @@ class SchedulerConfig:
|
||||
f"({self.num_lookahead_slots}) must be greater than or "
|
||||
"equal to 0.")
|
||||
|
||||
if self.num_scheduler_steps < 1:
|
||||
raise ValueError(
|
||||
"num_scheduler_steps "
|
||||
f"({self.num_scheduler_steps}) must be greater than or "
|
||||
"equal to 1.")
|
||||
|
||||
@property
def is_multi_step(self) -> bool:
    """Report whether more than one scheduler step is configured.

    Returns:
        True when ``self.num_scheduler_steps`` is greater than 1,
        False otherwise.
    """
    configured_steps = self.num_scheduler_steps
    return configured_steps > 1
|
||||
|
||||
|
||||
class DeviceConfig:
|
||||
device: Optional[torch.device]
|
||||
|
||||
Reference in New Issue
Block a user