[Model Runner V2] Fix error-handling (#35063)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-26 11:00:19 -08:00
committed by GitHub
parent 5e58bdc711
commit b6d5a17298

View File

@@ -227,6 +227,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: tuple | None = None
def update_max_model_len(self, max_model_len: int) -> None:
    """Propagate a new maximum model length to the runner and its request states.

    Args:
        max_model_len: The new maximum model length to apply. It is written
            both to this runner and to ``self.req_states`` so the two stay
            in sync.
    """
    # Update the per-request state tracker first, then the runner's own copy;
    # the two assignments are independent, so order does not matter.
    self.req_states.max_model_len = max_model_len
    self.max_model_len = max_model_len
@@ -388,6 +391,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert self.execute_model_state is not None
hidden_states, _, input_batch, _ = self.execute_model_state
self.execute_model_state = None
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@@ -1036,18 +1040,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states,
input_batch,
kv_connector_output,
) # type: ignore
)
return None
@torch.inference_mode()
def sample_tokens(
self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None
if self.execute_model_state is None:
# The prior execute_model call must have failed.
return None
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
self.execute_model_state
)
self.execute_model_state = None # type: ignore
self.execute_model_state = None
if not self.is_last_pp_rank:
# Non-last PP rank: hidden_states is None because this rank produced