[Feature] Support recording expert indices for rollout router replay (#28284)

Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: Hongxin Xu <70438206+xhx1022@users.noreply.github.com>
Signed-off-by: arlenxu <arlenxu@tencent.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: arlenxu <arlenxu@tencent.com>
Author: Hongxin Xu (committed via GitHub)
Date: 2026-01-12 22:23:04 +08:00
Commit: 49e6b86c91 (parent 0565f1fdec)
11 changed files with 463 additions and 3 deletions


@@ -7,6 +7,7 @@ from collections.abc import Mapping
 from typing import Any
 
 import msgspec
+import numpy as np
 import torch
 from vllm.lora.request import LoRARequest
@@ -139,7 +140,7 @@ class EngineCoreOutput(
     trace_headers: Mapping[str, str] | None = None
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0
-
+    routed_experts: np.ndarray | None = None
     # The number of NaNs in logits.
     # A value greater than 0 indicates that the output is corrupted.
     num_nans_in_logits: int = 0
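
For rollout router replay, the engine core can now attach the MoE expert indices chosen while generating a request. A minimal consumer-side sketch, assuming an EngineCoreOutput value named out; the array layout described in the comments is an assumption, since the hunks above only establish that the field is an optional np.ndarray:

import numpy as np

def summarize_routed_experts(out) -> None:
    # out is an EngineCoreOutput; routed_experts stays None unless the
    # engine recorded expert indices for this request.
    if out.routed_experts is None:
        print("no expert routing recorded")
        return
    experts: np.ndarray = out.routed_experts
    # One plausible layout (assumption): one entry per generated token
    # per MoE layer, each entry an expert index.
    print(f"dtype={experts.dtype}, shape={experts.shape}")
    print("distinct experts used:", np.unique(experts))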


@@ -7,6 +7,7 @@ from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, cast
 
+import numpy as np
 import torch
 from vllm.lora.request import LoRARequest
@@ -213,6 +214,7 @@ class RequestState:
         finish_reason: FinishReason | None,
         stop_reason: int | str | None,
         kv_transfer_params: dict[str, Any] | None = None,
+        routed_experts: np.ndarray | None = None,
     ) -> RequestOutput | PoolingRequestOutput | None:
         finished = finish_reason is not None
         final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
@@ -253,7 +255,9 @@ class RequestState:
             finished,
         )
 
-        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
+        output = self._new_completion_output(
+            new_token_ids, finish_reason, stop_reason, routed_experts
+        )
 
         if self.parent_req is None:
             outputs = [output]
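
The call-site change above shows the threading pattern used throughout this commit: the optional array is forwarded unchanged, with a default of None so existing callers keep working. A simplified standalone sketch of the same pattern (MiniCompletionOutput is a hypothetical stand-in, not the vLLM class):

from dataclasses import dataclass

import numpy as np

@dataclass
class MiniCompletionOutput:
    # Hypothetical stand-in for vllm's CompletionOutput.
    token_ids: list[int]
    routed_experts: np.ndarray | None = None

def make_output(
    token_ids: list[int],
    routed_experts: np.ndarray | None = None,
) -> MiniCompletionOutput:
    # Mirrors _new_completion_output: the array is passed through as-is;
    # the None default keeps callers that predate the field unchanged.
    return MiniCompletionOutput(token_ids=token_ids, routed_experts=routed_experts)

out = make_output([1, 2, 3], np.array([[0, 5], [3, 1], [2, 7]]))
assert out.routed_experts is not None and out.routed_experts.shape == (3, 2)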
@@ -316,6 +320,7 @@ class RequestState:
         token_ids: list[int],
         finish_reason: FinishReason | None,
         stop_reason: int | str | None,
+        routed_experts: np.ndarray | None = None,
     ) -> CompletionOutput:
         assert self.detokenizer is not None
         assert self.logprobs_processor is not None
@@ -336,6 +341,7 @@ class RequestState:
             index=self.request_index,
             text=text,
             token_ids=token_ids,
+            routed_experts=routed_experts,
             logprobs=logprobs,
             cumulative_logprob=self.logprobs_processor.cumulative_logprob,
             finish_reason=str(finish_reason) if finished else None,
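
With CompletionOutput carrying the field, read-side code can pull the recorded indices off a finished request. A hypothetical helper; req_out follows vLLM's public RequestOutput, whose outputs list holds one CompletionOutput per sequence:

import numpy as np

def collect_expert_traces(req_out) -> list[np.ndarray]:
    # Gather recorded expert indices from every completion of a request,
    # skipping completions where recording was off (field left as None).
    return [
        c.routed_experts
        for c in req_out.outputs
        if c.routed_experts is not None
    ]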
@@ -527,6 +533,7 @@ class OutputProcessor:
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
             kv_transfer_params = engine_core_output.kv_transfer_params
+            routed_experts = engine_core_output.routed_experts
             req_state.num_cached_tokens = engine_core_output.num_cached_tokens
             req_state.is_prefilling = False
@@ -552,6 +559,7 @@ class OutputProcessor:
                 finish_reason,
                 stop_reason,
                 kv_transfer_params,
+                routed_experts,
             ):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
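
End to end: once the engine core populates routed_experts, OutputProcessor forwards it into the user-visible output. A hypothetical usage sketch, assuming a build with this commit and an MoE checkpoint; note that the hunks shown here do not include the switch that turns recording on, so the field may simply come back None:

from vllm import LLM, SamplingParams

llm = LLM(model="some/moe-model")  # placeholder model name (assumption)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))

for req_out in outputs:
    for completion in req_out.outputs:
        if completion.routed_experts is not None:
            # e.g. replay the recorded router decisions during RL training.
            print(completion.routed_experts.shape)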