[Feature] Support recording expert indices for rollout router replay (#28284)
Signed-off-by: xhx1022 <1737006628@qq.com>
Signed-off-by: Hongxin Xu <70438206+xhx1022@users.noreply.github.com>
Signed-off-by: arlenxu <arlenxu@tencent.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: arlenxu <arlenxu@tencent.com>
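This change threads a per-request routed_experts array (the expert indices an MoE router selected for each sampled token) from EngineCoreOutput through the OutputProcessor into CompletionOutput, so rollout frameworks can replay the router's decisions during RL training. Below is a minimal consumer-side sketch of reading the new field back, not part of this diff; it assumes a build where expert recording is active so the field gets populated, and the model name is a placeholder.

# Consumer-side sketch (illustrative, not part of this commit): read the
# recorded expert indices from a finished request. If recording is not
# enabled, routed_experts stays None.
from vllm import LLM, SamplingParams

llm = LLM(model="path/to/moe-model")  # placeholder MoE checkpoint
params = SamplingParams(max_tokens=8)

for request_output in llm.generate(["Hello"], params):
    completion = request_output.outputs[0]
    if completion.routed_experts is not None:
        # Per this diff, a numpy array of expert indices recorded while
        # the tokens above were generated.
        print(completion.routed_experts.shape)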
@@ -7,6 +7,7 @@ from collections.abc import Mapping
 from typing import Any
 
 import msgspec
+import numpy as np
 import torch
 
 from vllm.lora.request import LoRARequest
@@ -139,7 +140,7 @@ class EngineCoreOutput(
     trace_headers: Mapping[str, str] | None = None
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0
-
+    routed_experts: np.ndarray | None = None
     # The number of NaNs in logits.
     # A value greater than 0 indicates that the output is corrupted.
     num_nans_in_logits: int = 0
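EngineCoreOutput is a msgspec.Struct that is serialized between the engine-core process and the frontend, so the new ndarray field has to survive encoding. A standalone sketch of one way such a field can round-trip through msgspec's enc_hook/dec_hook; these hooks are illustrative assumptions, not the ones vLLM actually registers.

import msgspec
import numpy as np

class Output(msgspec.Struct):
    routed_experts: np.ndarray | None = None

def enc_hook(obj):
    # msgpack has no native ndarray type; fall back to a nested list.
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise NotImplementedError(type(obj))

def dec_hook(typ, obj):
    # Called for annotations msgspec does not know, such as np.ndarray.
    if typ is np.ndarray:
        return np.asarray(obj)
    raise NotImplementedError(typ)

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)
decoder = msgspec.msgpack.Decoder(Output, dec_hook=dec_hook)

payload = encoder.encode(Output(routed_experts=np.array([[0, 3], [1, 2]])))
restored = decoder.decode(payload)  # routed_experts comes back as an ndarray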
@@ -7,6 +7,7 @@ from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, cast
 
+import numpy as np
 import torch
 
 from vllm.lora.request import LoRARequest
@@ -213,6 +214,7 @@ class RequestState:
         finish_reason: FinishReason | None,
         stop_reason: int | str | None,
         kv_transfer_params: dict[str, Any] | None = None,
+        routed_experts: np.ndarray | None = None,
     ) -> RequestOutput | PoolingRequestOutput | None:
         finished = finish_reason is not None
         final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
@@ -253,7 +255,9 @@ class RequestState:
                 finished,
             )
 
-        output = self._new_completion_output(new_token_ids, finish_reason, stop_reason)
+        output = self._new_completion_output(
+            new_token_ids, finish_reason, stop_reason, routed_experts
+        )
 
         if self.parent_req is None:
             outputs = [output]
@@ -316,6 +320,7 @@ class RequestState:
         token_ids: list[int],
         finish_reason: FinishReason | None,
         stop_reason: int | str | None,
+        routed_experts: np.ndarray | None = None,
     ) -> CompletionOutput:
         assert self.detokenizer is not None
         assert self.logprobs_processor is not None
@@ -336,6 +341,7 @@ class RequestState:
             index=self.request_index,
             text=text,
             token_ids=token_ids,
+            routed_experts=routed_experts,
             logprobs=logprobs,
             cumulative_logprob=self.logprobs_processor.cumulative_logprob,
             finish_reason=str(finish_reason) if finished else None,
@@ -527,6 +533,7 @@ class OutputProcessor:
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
             kv_transfer_params = engine_core_output.kv_transfer_params
+            routed_experts = engine_core_output.routed_experts
             req_state.num_cached_tokens = engine_core_output.num_cached_tokens
             req_state.is_prefilling = False
 
@@ -552,6 +559,7 @@ class OutputProcessor:
                 finish_reason,
                 stop_reason,
                 kv_transfer_params,
+                routed_experts,
             ):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
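Taken together, the plumbing follows the usual backward-compatible pattern: each function in the chain grows a trailing routed_experts parameter that defaults to None, so existing callers keep working unchanged. A self-contained toy illustration of that pattern; the names here are hypothetical, not vLLM code.

from dataclasses import dataclass

import numpy as np

@dataclass
class Output:
    token_ids: list[int]
    routed_experts: np.ndarray | None = None  # new optional field

def make_output(
    token_ids: list[int],
    routed_experts: np.ndarray | None = None,  # new trailing parameter
) -> Output:
    # Callers that predate the feature simply omit the argument.
    return Output(token_ids=token_ids, routed_experts=routed_experts)

legacy = make_output([1, 2, 3])                      # routed_experts is None
replay = make_output([1, 2, 3], np.array([[4, 7]]))  # experts recorded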