[V1][Spec decode] Move drafter to model runner (#13363)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -203,6 +203,7 @@ def test_schedule_partial_requests():
|
|||||||
req_ids=[request.request_id for request in requests],
|
req_ids=[request.request_id for request in requests],
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
)
|
)
|
||||||
@@ -259,6 +260,7 @@ def test_stop_via_update_from_output():
|
|||||||
sampled_token_ids=[[EOS_TOKEN_ID],
|
sampled_token_ids=[[EOS_TOKEN_ID],
|
||||||
[10,
|
[10,
|
||||||
11]], # First request hits EOS, second continues
|
11]], # First request hits EOS, second continues
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
@@ -307,6 +309,7 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
sampled_token_ids=[[10, 42, 12],
|
sampled_token_ids=[[10, 42, 12],
|
||||||
[13, 14]], # First request hits stop token
|
[13, 14]], # First request hits stop token
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
@@ -354,6 +357,7 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
sampled_token_ids=[[10, 11, 12],
|
sampled_token_ids=[[10, 11, 12],
|
||||||
[13]], # First request exceeds max_tokens
|
[13]], # First request exceeds max_tokens
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
@@ -394,6 +398,7 @@ def test_stop_via_update_from_output():
|
|||||||
req_ids=[requests[0].request_id],
|
req_ids=[requests[0].request_id],
|
||||||
req_id_to_index={requests[0].request_id: 0},
|
req_id_to_index={requests[0].request_id: 0},
|
||||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={})
|
prompt_logprobs_dict={})
|
||||||
|
|
||||||
@@ -434,6 +439,7 @@ def test_schedule_concurrent_batches():
|
|||||||
req_ids=[requests[0].request_id],
|
req_ids=[requests[0].request_id],
|
||||||
req_id_to_index={requests[0].request_id: 0},
|
req_id_to_index={requests[0].request_id: 0},
|
||||||
sampled_token_ids=[[0]],
|
sampled_token_ids=[[0]],
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
)
|
)
|
||||||
@@ -450,6 +456,7 @@ def test_schedule_concurrent_batches():
|
|||||||
req_ids=[requests[1].request_id],
|
req_ids=[requests[1].request_id],
|
||||||
req_id_to_index={requests[1].request_id: 0},
|
req_id_to_index={requests[1].request_id: 0},
|
||||||
sampled_token_ids=[[0]],
|
sampled_token_ids=[[0]],
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -474,6 +474,7 @@ class Scheduler:
|
|||||||
model_runner_output: "ModelRunnerOutput",
|
model_runner_output: "ModelRunnerOutput",
|
||||||
) -> EngineCoreOutputs:
|
) -> EngineCoreOutputs:
|
||||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||||
|
spec_token_ids = model_runner_output.spec_token_ids
|
||||||
logprobs = model_runner_output.logprobs
|
logprobs = model_runner_output.logprobs
|
||||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
@@ -530,13 +531,9 @@ class Scheduler:
|
|||||||
self.encoder_cache_manager.free_encoder_input(
|
self.encoder_cache_manager.free_encoder_input(
|
||||||
request, input_id)
|
request, input_id)
|
||||||
|
|
||||||
if request.num_computed_tokens >= request.num_tokens:
|
# Add newly generated spec token ids to the request.
|
||||||
# Clear the spec tokens as the request has generated
|
if spec_token_ids is not None:
|
||||||
# a new token. Here, We assume all spec tokens are verified
|
request.spec_token_ids = spec_token_ids[req_index]
|
||||||
# if we perform speculative decoding for this request.
|
|
||||||
# Therefore, we can clear all spec tokens after
|
|
||||||
# the generation step.
|
|
||||||
request.clear_spec_tokens()
|
|
||||||
|
|
||||||
# Get prompt logprobs for this request.
|
# Get prompt logprobs for this request.
|
||||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from vllm.v1.executor.abstract import Executor
|
|||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
|
||||||
from vllm.version import __version__ as VLLM_VERSION
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -86,15 +85,6 @@ class EngineCore:
|
|||||||
self.batch_queue_size)
|
self.batch_queue_size)
|
||||||
self.batch_queue = queue.Queue(self.batch_queue_size)
|
self.batch_queue = queue.Queue(self.batch_queue_size)
|
||||||
|
|
||||||
# Setup speculative decode.
|
|
||||||
# TODO: find a better way to check if we are using ngram.
|
|
||||||
self.use_spec_decode = False
|
|
||||||
if self.scheduler.speculative_config:
|
|
||||||
assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
|
|
||||||
, "Only ngram spec decode is supported in V1."
|
|
||||||
self.proposer = NgramProposer()
|
|
||||||
self.use_spec_decode = True
|
|
||||||
|
|
||||||
def _initialize_kv_caches(self,
|
def _initialize_kv_caches(self,
|
||||||
vllm_config: VllmConfig) -> Tuple[int, int]:
|
vllm_config: VllmConfig) -> Tuple[int, int]:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -158,9 +148,6 @@ class EngineCore:
|
|||||||
return EngineCoreOutputs(
|
return EngineCoreOutputs(
|
||||||
outputs=[], scheduler_stats=self.scheduler.make_stats())
|
outputs=[], scheduler_stats=self.scheduler.make_stats())
|
||||||
|
|
||||||
if self.use_spec_decode:
|
|
||||||
self.propose_tokens()
|
|
||||||
|
|
||||||
scheduler_output = self.scheduler.schedule()
|
scheduler_output = self.scheduler.schedule()
|
||||||
output = self.model_executor.execute_model(scheduler_output)
|
output = self.model_executor.execute_model(scheduler_output)
|
||||||
engine_core_outputs = self.scheduler.update_from_output(
|
engine_core_outputs = self.scheduler.update_from_output(
|
||||||
@@ -221,23 +208,6 @@ class EngineCore:
|
|||||||
def profile(self, is_start: bool = True):
|
def profile(self, is_start: bool = True):
|
||||||
self.model_executor.profile(is_start)
|
self.model_executor.profile(is_start)
|
||||||
|
|
||||||
def propose_tokens(self):
|
|
||||||
assert self.scheduler.speculative_config is not None
|
|
||||||
for req in self.scheduler.running:
|
|
||||||
# Ignore requests that are doing chunked prefill.
|
|
||||||
if req.num_computed_tokens < req.num_tokens - 1:
|
|
||||||
continue
|
|
||||||
# Ignore requests that already have spec tokens.
|
|
||||||
if req.spec_token_ids:
|
|
||||||
continue
|
|
||||||
spec_tokens = self.proposer.propose(
|
|
||||||
req.all_token_ids,
|
|
||||||
self.scheduler.speculative_config.ngram_prompt_lookup_min,
|
|
||||||
self.scheduler.speculative_config.num_speculative_tokens,
|
|
||||||
)
|
|
||||||
if spec_tokens:
|
|
||||||
req.append_spec_token_ids(spec_tokens)
|
|
||||||
|
|
||||||
def reset_prefix_cache(self):
|
def reset_prefix_cache(self):
|
||||||
self.scheduler.reset_prefix_cache()
|
self.scheduler.reset_prefix_cache()
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ class ModelRunnerOutput:
|
|||||||
# each request due to speculative/jump decoding.
|
# each request due to speculative/jump decoding.
|
||||||
sampled_token_ids: List[List[int]]
|
sampled_token_ids: List[List[int]]
|
||||||
|
|
||||||
|
# num_reqs x num_spec_tokens
|
||||||
|
spec_token_ids: Optional[List[List[int]]]
|
||||||
|
|
||||||
# [num_reqs, max_num_logprobs + 1]
|
# [num_reqs, max_num_logprobs + 1]
|
||||||
# [num_reqs, max_num_logprobs + 1]
|
# [num_reqs, max_num_logprobs + 1]
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
|
|||||||
@@ -104,18 +104,6 @@ class Request:
|
|||||||
self._output_token_ids.extend(token_ids)
|
self._output_token_ids.extend(token_ids)
|
||||||
self._all_token_ids.extend(token_ids)
|
self._all_token_ids.extend(token_ids)
|
||||||
|
|
||||||
def append_spec_token_ids(
|
|
||||||
self,
|
|
||||||
token_ids: Union[int, List[int]],
|
|
||||||
) -> None:
|
|
||||||
if isinstance(token_ids, int):
|
|
||||||
self.spec_token_ids.append(token_ids)
|
|
||||||
else:
|
|
||||||
self.spec_token_ids.extend(token_ids)
|
|
||||||
|
|
||||||
def clear_spec_tokens(self) -> None:
|
|
||||||
self.spec_token_ids.clear()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_tokens(self) -> int:
|
def num_tokens(self) -> int:
|
||||||
return len(self._all_token_ids)
|
return len(self._all_token_ids)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from vllm.v1.utils import ConstantList
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class NgramProposer:
|
class NgramProposer:
|
||||||
@@ -9,8 +9,12 @@ class NgramProposer:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def propose(self, context_token_ids: ConstantList[int], n: int,
|
def propose(
|
||||||
k: int) -> Optional[List[int]]:
|
self,
|
||||||
|
context_token_ids: np.ndarray,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
"""Proposes the next sequence of tokens based on n-gram pattern
|
"""Proposes the next sequence of tokens based on n-gram pattern
|
||||||
matching in the context. The function finds matches of the last n
|
matching in the context. The function finds matches of the last n
|
||||||
tokens in the previous context, and returns k tokens that followed
|
tokens in the previous context, and returns k tokens that followed
|
||||||
@@ -25,8 +29,8 @@ class NgramProposer:
|
|||||||
the maximum amount of tokens until the end.
|
the maximum amount of tokens until the end.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[int]: The sequence of tokens that followed
|
np.ndarray: The sequence of tokens that followed
|
||||||
the matched n-gram in the context.
|
the matched n-gram in the context.
|
||||||
None: If no matching n-gram pattern is found.
|
None: If no matching n-gram pattern is found.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -66,9 +70,12 @@ class NgramProposer:
|
|||||||
return lps
|
return lps
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
|
def _find_subarray_kmp(
|
||||||
k: int) -> Optional[List[int]]:
|
context_token_ids: np.ndarray,
|
||||||
context_len = len(context_token_ids)
|
n: int,
|
||||||
|
k: int,
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
context_len = context_token_ids.shape[0]
|
||||||
assert n > 0
|
assert n > 0
|
||||||
|
|
||||||
pattern = context_token_ids[-n:]
|
pattern = context_token_ids[-n:]
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class InputBatch:
|
|||||||
)
|
)
|
||||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
|
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
|
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
|
||||||
|
|
||||||
@@ -217,7 +218,11 @@ class InputBatch:
|
|||||||
end_idx = start_idx + len(request.output_token_ids)
|
end_idx = start_idx + len(request.output_token_ids)
|
||||||
self.token_ids_cpu[req_index,
|
self.token_ids_cpu[req_index,
|
||||||
start_idx:end_idx] = request.output_token_ids
|
start_idx:end_idx] = request.output_token_ids
|
||||||
|
# Number of token ids in token_ids_cpu.
|
||||||
|
# NOTE(woosuk): This may include spec decode tokens.
|
||||||
self.num_tokens[req_index] = request.num_tokens
|
self.num_tokens[req_index] = request.num_tokens
|
||||||
|
# Number of tokens without spec decode tokens.
|
||||||
|
self.num_tokens_no_spec[req_index] = request.num_tokens
|
||||||
|
|
||||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||||
self.block_table.add_row(req_index, request.block_ids)
|
self.block_table.add_row(req_index, request.block_ids)
|
||||||
@@ -356,6 +361,8 @@ class InputBatch:
|
|||||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||||
last_req_index, :num_tokens]
|
last_req_index, :num_tokens]
|
||||||
self.num_tokens[empty_index] = num_tokens
|
self.num_tokens[empty_index] = num_tokens
|
||||||
|
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
||||||
|
last_req_index]
|
||||||
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
|
||||||
last_req_index]
|
last_req_index]
|
||||||
self.num_computed_tokens_cpu[
|
self.num_computed_tokens_cpu[
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
|||||||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
|
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
|
||||||
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||||
@@ -117,6 +118,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# req_id -> (input_id -> encoder_output)
|
# req_id -> (input_id -> encoder_output)
|
||||||
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||||
|
|
||||||
|
# Set up speculative decoding.
|
||||||
|
self.use_spec_decode = False
|
||||||
|
if self.speculative_config:
|
||||||
|
# TODO: find a better way to check if we are using ngram.
|
||||||
|
assert self.speculative_config.ngram_prompt_lookup_min, \
|
||||||
|
"Currently, only ngram spec decode is supported in V1."
|
||||||
|
self.drafter = NgramProposer()
|
||||||
|
self.use_spec_decode = True
|
||||||
|
|
||||||
# Request states.
|
# Request states.
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
# Persistent batch.
|
# Persistent batch.
|
||||||
@@ -367,6 +377,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.token_ids_cpu[
|
self.input_batch.token_ids_cpu[
|
||||||
req_index,
|
req_index,
|
||||||
start_token_index:end_token_index] = req_data.new_token_ids
|
start_token_index:end_token_index] = req_data.new_token_ids
|
||||||
|
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
|
||||||
# Add spec_token_ids to token_ids_cpu.
|
# Add spec_token_ids to token_ids_cpu.
|
||||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||||
req_id, [])
|
req_id, [])
|
||||||
@@ -1009,15 +1020,51 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for seq in sampled_token_ids[valid_mask].split(gen_lens)
|
for seq in sampled_token_ids[valid_mask].split(gen_lens)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if not self.use_spec_decode:
|
||||||
|
spec_token_ids = None
|
||||||
|
else:
|
||||||
|
spec_token_ids = self.generate_draft_token_ids(
|
||||||
|
valid_sampled_token_ids)
|
||||||
|
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
sampled_token_ids=valid_sampled_token_ids,
|
sampled_token_ids=valid_sampled_token_ids,
|
||||||
|
spec_token_ids=spec_token_ids,
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
)
|
)
|
||||||
return model_runner_output
|
return model_runner_output
|
||||||
|
|
||||||
|
def generate_draft_token_ids(
|
||||||
|
self,
|
||||||
|
sampled_token_ids: List[List[int]],
|
||||||
|
) -> List[List[int]]:
|
||||||
|
# TODO(woosuk): Optimize.
|
||||||
|
num_reqs = len(sampled_token_ids)
|
||||||
|
draft_token_ids: List[List[int]] = []
|
||||||
|
for i in range(num_reqs):
|
||||||
|
if len(sampled_token_ids[i]) == 0:
|
||||||
|
# Skip speculative decoding.
|
||||||
|
draft_token_ids.append([])
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Add sampled_token_ids to token_ids_cpu.
|
||||||
|
start_idx = self.input_batch.num_tokens_no_spec[i]
|
||||||
|
end_idx = start_idx + len(sampled_token_ids[i])
|
||||||
|
self.input_batch.token_ids_cpu[
|
||||||
|
i, start_idx:end_idx] = sampled_token_ids[i]
|
||||||
|
drafter_output = self.drafter.propose(
|
||||||
|
self.input_batch.token_ids_cpu[i, :end_idx],
|
||||||
|
self.speculative_config.ngram_prompt_lookup_min,
|
||||||
|
self.speculative_config.num_speculative_tokens,
|
||||||
|
)
|
||||||
|
if drafter_output is None or len(drafter_output) == 0:
|
||||||
|
draft_token_ids.append([])
|
||||||
|
else:
|
||||||
|
draft_token_ids.append(drafter_output.tolist())
|
||||||
|
return draft_token_ids
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
logger.info("Starting to load model %s...", self.model_config.model)
|
logger.info("Starting to load model %s...", self.model_config.model)
|
||||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||||
|
|||||||
@@ -696,6 +696,7 @@ class TPUModelRunner:
|
|||||||
req_ids=all_req_ids,
|
req_ids=all_req_ids,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
|
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
|
||||||
|
spec_token_ids=None,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
|
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user