[Model Runner V2] Fix error-handling (#35063)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -227,6 +227,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# KV Connector if configured.
|
||||
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
|
||||
|
||||
# For transferring state from execute_model to subsequent sample_tokens call.
|
||||
self.execute_model_state: tuple | None = None
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
self.req_states.max_model_len = max_model_len
|
||||
@@ -388,6 +391,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, _, input_batch, _ = self.execute_model_state
|
||||
self.execute_model_state = None
|
||||
assert hidden_states is not None # Last PP rank always has hidden_states
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@@ -1036,18 +1040,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states,
|
||||
input_batch,
|
||||
kv_connector_output,
|
||||
) # type: ignore
|
||||
)
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample_tokens(
|
||||
self, grammar_output: GrammarOutput | None
|
||||
) -> AsyncOutput | ModelRunnerOutput | None:
|
||||
assert self.execute_model_state is not None
|
||||
if self.execute_model_state is None:
|
||||
# The prior execute_model call must have failed.
|
||||
return None
|
||||
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
|
||||
self.execute_model_state
|
||||
)
|
||||
self.execute_model_state = None # type: ignore
|
||||
self.execute_model_state = None
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
# Non-last PP rank: hidden_states is None because this rank produced
|
||||
|
||||
Reference in New Issue
Block a user