diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index d90b0dc01..2fddbd01d 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -156,8 +156,8 @@ def _prepare_prefill_inputs_kernel(
     next_prefill_tokens_ptr,
     idx_mapping_ptr,
     query_start_loc_ptr,
-    prefill_token_ids_ptr,
-    prefill_token_ids_stride,
+    all_token_ids_ptr,
+    all_token_ids_stride,
     prefill_lens_ptr,
     num_computed_tokens_ptr,
     BLOCK_SIZE: tl.constexpr,
@@ -174,16 +174,16 @@ def _prepare_prefill_inputs_kernel(
     query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
     query_len = query_end - query_start
 
-    prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
+    request_ptr = all_token_ids_ptr + req_state_idx * all_token_ids_stride
     for i in range(0, query_len, BLOCK_SIZE):
         block = i + tl.arange(0, BLOCK_SIZE)
         mask = block < query_len
-        tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
+        tokens = tl.load(request_ptr + num_computed + block, mask=mask)
         tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
 
     next_pos = num_computed + query_len
     if next_pos < prefill_len:
-        next_token = tl.load(prefill_ptr + next_pos)
+        next_token = tl.load(request_ptr + next_pos)
         tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
 
 
@@ -192,7 +192,7 @@ def prepare_prefill_inputs(
     next_prefill_tokens: torch.Tensor,
     idx_mapping: torch.Tensor,
     query_start_loc: torch.Tensor,
-    prefill_token_ids: torch.Tensor,
+    all_token_ids: torch.Tensor,
     prefill_len: torch.Tensor,
     num_computed_tokens: torch.Tensor,
 ) -> None:
@@ -202,8 +202,8 @@ def prepare_prefill_inputs(
         next_prefill_tokens,
         idx_mapping,
         query_start_loc,
-        prefill_token_ids,
-        prefill_token_ids.stride(0),
+        all_token_ids,
+        all_token_ids.stride(0),
         prefill_len,
         num_computed_tokens,
         BLOCK_SIZE=1024,
@@ -423,16 +423,21 @@ def _post_update_kernel(
     num_sampled_ptr,
     num_rejected_ptr,
     query_start_loc_ptr,
+    all_token_ids_ptr,
+    all_token_ids_stride,
+    total_len_ptr,
 ):
     req_id = tl.program_id(0)
     req_state_idx = tl.load(idx_mapping_ptr + req_id)
+    total_len = tl.load(total_len_ptr + req_state_idx)
 
     num_sampled = tl.load(num_sampled_ptr + req_id)
     if num_sampled > 0:
         token_id = tl.load(
             sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
         )
         tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
+        tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
 
     for i in range(num_sampled):
         token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
@@ -442,6 +447,10 @@
         count = tl.load(token_ptr)
         count += 1
         tl.store(token_ptr, count)
+        tl.store(
+            all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
+            token_id,
+        )
 
     query_start = tl.load(query_start_loc_ptr + req_id)
     query_end = tl.load(query_start_loc_ptr + req_id + 1)
@@ -470,6 +479,10 @@ def post_update(
     num_rejected: torch.Tensor,
     # [num_reqs + 1]
     query_start_loc: torch.Tensor,
+    # [max_num_reqs, max_model_len]
+    all_token_ids: torch.Tensor,
+    # [max_num_reqs]
+    total_len: torch.Tensor,
 ) -> None:
     num_reqs = idx_mapping.shape[0]
     _post_update_kernel[(num_reqs,)](
@@ -483,6 +496,9 @@
         num_sampled,
         num_rejected,
         query_start_loc,
+        all_token_ids,
+        all_token_ids.stride(0),
+        total_len,
         num_warps=1,
     )
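For intuition, here is a minimal PyTorch sketch of the new bookkeeping that _post_update_kernel performs per request: the tokens actually sampled this step are appended to the request's row of all_token_ids starting at total_len, and total_len is then advanced. The helper below is illustrative only (names and shapes mirror the kernel arguments) and omits the existing last-token and bincount updates.

import torch

def post_update_token_history(
    all_token_ids: torch.Tensor,   # [max_num_reqs, max_model_len]
    total_len: torch.Tensor,       # [max_num_reqs]
    sampled_tokens: torch.Tensor,  # [num_reqs, max_sampled_per_step]
    num_sampled: torch.Tensor,     # [num_reqs]
    idx_mapping: torch.Tensor,     # [num_reqs] batch slot -> request state index
) -> None:
    # One iteration per request; the Triton kernel runs these as parallel programs.
    for req_id in range(idx_mapping.shape[0]):
        req_state_idx = int(idx_mapping[req_id])
        n = int(num_sampled[req_id])
        if n == 0:
            continue
        start = int(total_len[req_state_idx])
        # Append this step's sampled tokens to the request's token history.
        all_token_ids[req_state_idx, start:start + n] = sampled_tokens[req_id, :n]
        # Advance total_len so the next step appends right after them.
        total_len[req_state_idx] = start + n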
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index d6b87bd71..380da12cd 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -151,6 +151,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             vocab_size=self.vocab_size,
             device=self.device,
+            all_token_ids=self.req_states.all_token_ids.gpu,
+            prompt_len=self.req_states.prompt_len.gpu,
+            total_len=self.req_states.total_len.gpu,
             logprobs_mode=self.model_config.logprobs_mode,
             num_speculative_tokens=self.num_speculative_steps + 1,
         )
@@ -448,7 +451,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.req_states.add_request(
                 req_id=req_id,
                 prompt_len=prompt_len,
-                prefill_token_ids=new_req_data.prefill_token_ids,
+                all_token_ids=new_req_data.prefill_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
             )
             req_index = self.req_states.req_id_to_index[req_id]
@@ -479,9 +482,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if scheduler_output.scheduled_new_reqs:
             self.req_states.apply_staged_writes()
             self.sampler.apply_staged_writes(
-                self.req_states.prefill_token_ids.gpu,
+                self.req_states.all_token_ids.gpu,
                 self.req_states.prefill_len.np,
-                self.req_states.prompt_len,
+                self.req_states.prompt_len.np,
             )
             if self.uses_mrope:
                 self.mrope_states.apply_staged_writes()
@@ -570,7 +573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.req_states.next_prefill_tokens,
             idx_mapping,
             query_start_loc,
-            self.req_states.prefill_token_ids.gpu,
+            self.req_states.all_token_ids.gpu,
             self.req_states.prefill_len.gpu,
             self.req_states.num_computed_tokens.gpu,
         )
@@ -759,6 +762,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_sampled,
             num_rejected,
             input_batch.query_start_loc,
+            self.req_states.all_token_ids.gpu,
+            self.req_states.total_len.gpu,
         )
 
         # Update the number of computed prefill tokens.
@@ -924,9 +929,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.model.compute_logits,
             hidden_states,
             input_batch,
-            self.req_states.prefill_token_ids.gpu,
+            self.req_states.all_token_ids.gpu,
             self.req_states.num_computed_tokens.gpu,
-            self.req_states.prompt_len,
+            self.req_states.prompt_len.np,
             self.req_states.prefill_len.np,
             self.req_states.num_computed_prefill_tokens,
         )
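A note on the .np / .gpu suffixes appearing in these call sites: prompt_len changes from a bare NumPy array to a UvaBackedTensor, so host-side consumers (penalty bincounts, prompt-logprob chunking) now read prompt_len.np, while kernels receive prompt_len.gpu. The snippet below sketches the assumed contract, inferred from how this diff uses the class rather than from its definition.

# Assumed UvaBackedTensor usage pattern (inferred from this diff):
#   .np  -> host-visible NumPy view used on the CPU path
#   .gpu -> device-visible tensor over the same UVA allocation
#   copy_to_uva() publishes staged host writes to the device view
req_states.prompt_len.np[req_idx] = 100   # staged host-side write (example value)
req_states.prompt_len.copy_to_uva()       # make it visible to kernels
kernel_arg = req_states.prompt_len.gpu    # what Triton kernels receive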
diff --git a/vllm/v1/worker/gpu/sample/bad_words.py b/vllm/v1/worker/gpu/sample/bad_words.py
new file mode 100644
index 000000000..c6f8f8af2
--- /dev/null
+++ b/vllm/v1/worker/gpu/sample/bad_words.py
@@ -0,0 +1,209 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
+import torch
+
+from vllm.sampling_params import SamplingParams
+from vllm.triton_utils import tl, triton
+from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
+
+MAX_BAD_WORDS_TOTAL_TOKENS = 1024  # Max total tokens for all bad words per request
+MAX_NUM_BAD_WORDS = 128  # Max number of bad words per request
+
+
+class BadWordsState:
+    def __init__(
+        self,
+        all_token_ids: torch.Tensor,
+        prompt_len: torch.Tensor,
+        total_len: torch.Tensor,
+    ):
+        self.all_token_ids = all_token_ids
+        self.prompt_len = prompt_len
+        self.total_len = total_len
+
+        self.max_num_reqs = prompt_len.shape[0]
+        self.device = prompt_len.device
+
+        # flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
+        self.bad_word_token_ids = StagedWriteTensor(
+            (self.max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        # cumulative offsets of bad words: [max_num_reqs, MAX_NUM_BAD_WORDS + 1]
+        self.bad_word_offsets = StagedWriteTensor(
+            (self.max_num_reqs, MAX_NUM_BAD_WORDS + 1),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        # number of bad words per request
+        self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
+        # whether request uses bad words
+        self.use_bad_words = np.zeros(self.max_num_reqs, dtype=bool)
+
+    def add_request(
+        self,
+        req_idx: int,
+        sampling_params: SamplingParams,
+    ) -> None:
+        bad_words_token_ids = sampling_params.bad_words_token_ids
+        if not bad_words_token_ids:
+            self.num_bad_words.np[req_idx] = 0
+            self.use_bad_words[req_idx] = False
+            return
+
+        num_bad_words = len(bad_words_token_ids)
+        if num_bad_words > MAX_NUM_BAD_WORDS:
+            raise ValueError(
+                f"Too many bad words: {num_bad_words}. "
+                f"The max number is {MAX_NUM_BAD_WORDS}."
+            )
+
+        # Flatten bad words and compute offsets
+        flattened_tokens: list[int] = []
+        offsets: list[int] = [0]
+        for bad_word in bad_words_token_ids:
+            flattened_tokens.extend(bad_word)
+            offsets.append(len(flattened_tokens))
+
+        if len(flattened_tokens) > MAX_BAD_WORDS_TOTAL_TOKENS:
+            raise ValueError(
+                f"Too many total bad word tokens: {len(flattened_tokens)}. "
+                f"The max is {MAX_BAD_WORDS_TOTAL_TOKENS}."
+            )
+
+        # Stage writes
+        self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
+        self.bad_word_offsets.stage_write(req_idx, 0, offsets)
+        self.num_bad_words.np[req_idx] = num_bad_words
+        self.use_bad_words[req_idx] = True
+
+    def apply_staged_writes(self) -> None:
+        self.num_bad_words.copy_to_uva()
+        self.bad_word_token_ids.apply_write()
+        self.bad_word_offsets.apply_write()
+
+    def apply_bad_words(
+        self,
+        logits: torch.Tensor,
+        idx_mapping: torch.Tensor,
+        idx_mapping_np: np.ndarray,
+        input_ids: torch.Tensor,
+        expanded_local_pos: torch.Tensor,
+    ) -> None:
+        if not np.any(self.use_bad_words[idx_mapping_np]):
+            # No request uses bad words. Skip the kernel launch.
+            return
+
+        actual_max_num_bad_words = int(np.max(self.num_bad_words.np[idx_mapping_np]))
+        apply_bad_words(
+            logits,
+            idx_mapping,
+            self.bad_word_token_ids.gpu,
+            self.bad_word_offsets.gpu,
+            self.num_bad_words.gpu,
+            self.all_token_ids,
+            self.prompt_len,
+            self.total_len,
+            input_ids,
+            expanded_local_pos,
+            actual_max_num_bad_words,
+        )
+
+
+@triton.jit
+def _bad_words_kernel(
+    logits_ptr,
+    logits_stride,
+    expanded_idx_mapping_ptr,
+    bad_word_token_ids_ptr,
+    bad_word_token_ids_stride,
+    bad_word_offsets_ptr,
+    bad_word_offsets_stride,
+    num_bad_words_ptr,
+    all_token_ids_ptr,
+    all_token_ids_stride,
+    prompt_len_ptr,
+    total_len_ptr,
+    input_ids_ptr,
+    expanded_local_pos_ptr,
+):
+    logit_idx = tl.program_id(0)
+    bw_idx = tl.program_id(1)
+
+    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
+    num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
+
+    if bw_idx >= num_bad_words:
+        return
+
+    pos = tl.load(expanded_local_pos_ptr + logit_idx)
+    cur_req_first_pos = logit_idx - pos
+
+    prompt_len = tl.load(prompt_len_ptr + req_state_idx)
+    total_len = tl.load(total_len_ptr + req_state_idx)
+    output_len = total_len - prompt_len
+    effective_len = output_len + pos
+
+    bd_offsets_base = bad_word_offsets_ptr + req_state_idx * bad_word_offsets_stride
+    bd_tokens_base = bad_word_token_ids_ptr + req_state_idx * bad_word_token_ids_stride
+    output_base = all_token_ids_ptr + req_state_idx * all_token_ids_stride + prompt_len
+
+    start = tl.load(bd_offsets_base + bw_idx)
+    end = tl.load(bd_offsets_base + bw_idx + 1)
+    bad_word_len = end - start
+    prefix_len = bad_word_len - 1
+
+    if prefix_len > effective_len:
+        return
+
+    last_token = tl.load(bd_tokens_base + end - 1)
+    match = 1
+    for i in range(prefix_len):
+        expected = tl.load(bd_tokens_base + start + i)
+        actual_pos = effective_len - prefix_len + i
+
+        from_spec_input = actual_pos >= output_len
+        if from_spec_input:
+            spec_offset = actual_pos - output_len
+            actual = tl.load(input_ids_ptr + cur_req_first_pos + spec_offset)
+        else:
+            actual = tl.load(output_base + actual_pos)
+
+        match = match & (expected == actual)
+
+    if match:
+        tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))
+
+
+def apply_bad_words(
+    logits: torch.Tensor,
+    expanded_idx_mapping: torch.Tensor,
+    bad_word_token_ids: torch.Tensor,
+    bad_word_offsets: torch.Tensor,
+    num_bad_words: torch.Tensor,
+    all_token_ids: torch.Tensor,
+    prompt_len: torch.Tensor,
+    total_len: torch.Tensor,
+    input_ids: torch.Tensor,
+    expanded_local_pos: torch.Tensor,
+    max_num_bad_words: int,
+) -> None:
+    total_num_tokens = logits.shape[0]
+    _bad_words_kernel[(total_num_tokens, max_num_bad_words)](
+        logits,
+        logits.stride(0),
+        expanded_idx_mapping,
+        bad_word_token_ids,
+        bad_word_token_ids.stride(0),
+        bad_word_offsets,
+        bad_word_offsets.stride(0),
+        num_bad_words,
+        all_token_ids,
+        all_token_ids.stride(0),
+        prompt_len,
+        total_len,
+        input_ids,
+        expanded_local_pos,
+    )
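As a plain-Python reference for the rule _bad_words_kernel applies to each (logits row, bad word) pair: a bad word of length L bans only its final token, and only when its first L - 1 tokens equal the most recent L - 1 tokens of the request's history, which the kernel stitches together from stored output tokens (all_token_ids past prompt_len) and any draft tokens already written into input_ids for this step. The function name and argument layout below are illustrative.

import numpy as np

def ban_bad_words_reference(
    logits_row: np.ndarray,      # [vocab_size] logits for one position
    bad_words: list[list[int]],  # each entry is a token-id sequence to ban
    output_tokens: list[int],    # tokens generated so far (after the prompt)
    draft_tokens: list[int],     # speculative tokens preceding this position
) -> None:
    history = output_tokens + draft_tokens
    for word in bad_words:
        prefix, last = word[:-1], word[-1]
        if len(prefix) > len(history):
            continue  # not enough context yet; the kernel returns early here
        # Ban the final token only if emitting it would complete the bad word.
        if prefix == history[len(history) - len(prefix):]:
            logits_row[last] = -np.inf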
diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py
index 24928fd10..8671dd7e0 100644
--- a/vllm/v1/worker/gpu/sample/penalties.py
+++ b/vllm/v1/worker/gpu/sample/penalties.py
@@ -51,14 +51,14 @@ class PenaltiesState:
 
     def apply_staged_writes(
         self,
-        prefill_token_ids: torch.Tensor,
+        all_token_ids: torch.Tensor,
         prefill_lens: np.ndarray,
         prompt_lens: np.ndarray,
     ) -> None:
         # TODO(woosuk): Optimize this.
         for req_idx in self._penalties_reqs:
             bincount(
-                prefill_token_ids[req_idx],
+                all_token_ids[req_idx],
                 int(prefill_lens[req_idx]),
                 int(prompt_lens[req_idx]),
                 self.prompt_bin_mask[req_idx],
@@ -216,7 +216,7 @@ def apply_penalties(
 
 @triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
 def _bincount_kernel(
-    prefill_token_ids_ptr,
+    all_token_ids_ptr,
     prefill_len,
     prompt_len,
     prompt_bin_mask_ptr,
@@ -230,20 +230,20 @@
     block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     if block_idx * BLOCK_SIZE < prompt_len:
         mask = block < prompt_len
-        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
-        idx = prefill_tokens // 32
-        bit_idx = prefill_tokens % 32
+        prompt_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
+        idx = prompt_tokens // 32
+        bit_idx = prompt_tokens % 32
         bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
         tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
     if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
         mask = block < prefill_len
         mask &= block >= prompt_len
-        prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
-        tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
+        output_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
+        tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=mask)
 
 
 def bincount(
-    prefill_token_ids: torch.Tensor,
+    all_token_ids: torch.Tensor,
     prefill_len: int,
     prompt_len: int,
     prompt_bin_mask: torch.Tensor,
@@ -254,7 +254,7 @@
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
     _bincount_kernel[(num_blocks,)](
-        prefill_token_ids,
+        all_token_ids,
         prefill_len,
         prompt_len,
         prompt_bin_mask,
diff --git a/vllm/v1/worker/gpu/sample/prompt_logprob.py b/vllm/v1/worker/gpu/sample/prompt_logprob.py
index 76b9af3a3..1915a0539 100644
--- a/vllm/v1/worker/gpu/sample/prompt_logprob.py
+++ b/vllm/v1/worker/gpu/sample/prompt_logprob.py
@@ -36,7 +36,7 @@ class PromptLogprobsWorker:
         hidden_states: torch.Tensor,
         input_batch: InputBatch,
         # [max_num_reqs, max_model_len]
-        prefill_token_ids: torch.Tensor,
+        all_token_ids: torch.Tensor,
         # [max_num_reqs]
         num_computed_tokens: torch.Tensor,
         # [max_num_reqs]
@@ -70,7 +70,7 @@ class PromptLogprobsWorker:
             input_batch.query_start_loc,
             input_batch.idx_mapping,
             num_computed_tokens,
-            prefill_token_ids,
+            all_token_ids,
         )
         # Compute the prompt logprobs.
         prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
@@ -132,8 +132,8 @@ def _prompt_logprobs_token_ids_kernel(
     query_start_loc_ptr,
     idx_mapping_ptr,
     num_computed_tokens_ptr,
-    prefill_token_ids_ptr,
-    prefill_token_ids_stride,
+    all_token_ids_ptr,
+    all_token_ids_stride,
     BLOCK_SIZE: tl.constexpr,
 ):
     batch_idx = tl.program_id(0)
@@ -151,9 +151,7 @@
     # because the logprob is computed for the next token.
     target_pos = num_computed_tokens + 1 + block
     token_ids = tl.load(
-        prefill_token_ids_ptr
-        + req_state_idx * prefill_token_ids_stride
-        + target_pos,
+        all_token_ids_ptr + req_state_idx * all_token_ids_stride + target_pos,
         mask=mask,
     )
     tl.store(
@@ -166,7 +164,7 @@ def get_prompt_logprobs_token_ids(
     query_start_loc: torch.Tensor,
     idx_mapping: torch.Tensor,
     num_computed_tokens: torch.Tensor,
-    prefill_token_ids: torch.Tensor,
+    all_token_ids: torch.Tensor,
 ) -> torch.Tensor:
     token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
     num_reqs = idx_mapping.shape[0]
@@ -175,8 +173,8 @@
         query_start_loc,
         idx_mapping,
         num_computed_tokens,
-        prefill_token_ids,
-        prefill_token_ids.stride(0),
+        all_token_ids,
+        all_token_ids.stride(0),
         BLOCK_SIZE=1024,
     )
     return token_ids
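To summarize what the bincount helper above computes per request (now reading from all_token_ids): tokens in [0, prompt_len) only set presence bits in a packed bitmask, while tokens in [prompt_len, prefill_len) get full occurrence counts, since penalties treat prompt and output occurrences differently. A NumPy sketch, with uint32 words standing in for the kernel's int32 bitmask packing; the helper name is illustrative.

import numpy as np

def bincount_reference(
    all_token_ids: np.ndarray,  # [max_model_len] token ids of one request
    prefill_len: int,
    prompt_len: int,
    vocab_size: int,
) -> tuple[np.ndarray, np.ndarray]:
    # Presence bitmask over the vocabulary, packed 32 tokens per word.
    prompt_bin_mask = np.zeros((vocab_size + 31) // 32, dtype=np.uint32)
    for tok in all_token_ids[:prompt_len]:
        prompt_bin_mask[tok // 32] |= np.uint32(1) << np.uint32(tok % 32)
    # Occurrence counts for output tokens carried into the prefill.
    output_bin_counts = np.zeros(vocab_size, dtype=np.int32)
    for tok in all_token_ids[prompt_len:prefill_len]:
        output_bin_counts[tok] += 1
    return prompt_bin_mask, output_bin_counts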
diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py
index 5935446f8..d5f66a39e 100644
--- a/vllm/v1/worker/gpu/sample/sampler.py
+++ b/vllm/v1/worker/gpu/sample/sampler.py
@@ -8,6 +8,7 @@ import vllm.envs as envs
 from vllm.config.model import LogprobsMode
 from vllm.sampling_params import SamplingParams
 from vllm.v1.worker.gpu.metrics.logits import get_num_nans
+from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
 from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
@@ -22,6 +23,9 @@ class Sampler:
         max_num_reqs: int,
         vocab_size: int,
         device: torch.device,
+        all_token_ids: torch.Tensor,
+        prompt_len: torch.Tensor,
+        total_len: torch.Tensor,
         logprobs_mode: LogprobsMode = "raw_logprobs",
         num_speculative_tokens: int = 1,
     ):
@@ -33,6 +37,7 @@ class Sampler:
         self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
         self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
         self.logit_bias_state = LogitBiasState(max_num_reqs, device)
+        self.bad_words_state = BadWordsState(all_token_ids, prompt_len, total_len)
         self.num_speculative_tokens = num_speculative_tokens
 
     def add_request(
@@ -41,18 +46,20 @@ class Sampler:
         self.sampling_states.add_request(req_idx, sampling_params)
         self.penalties_state.add_request(req_idx, sampling_params)
         self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
+        self.bad_words_state.add_request(req_idx, sampling_params)
 
     def apply_staged_writes(
         self,
-        prefill_token_ids: torch.Tensor,
+        all_token_ids: torch.Tensor,
         prefill_lens: np.ndarray,
         prompt_lens: np.ndarray,
     ) -> None:
         self.sampling_states.apply_staged_writes()
         self.penalties_state.apply_staged_writes(
-            prefill_token_ids, prefill_lens, prompt_lens
+            all_token_ids, prefill_lens, prompt_lens
         )
         self.logit_bias_state.apply_staged_writes()
+        self.bad_words_state.apply_staged_writes()
 
     def __call__(
         self,
@@ -124,6 +131,15 @@ class Sampler:
             self.num_speculative_tokens,
         )
 
+        # Apply bad words masking in place.
+        self.bad_words_state.apply_bad_words(
+            logits,
+            idx_mapping,
+            idx_mapping_np,
+            input_ids,
+            expanded_local_pos,
+        )
+
         # Apply temperature in place.
         self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
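The new masking step runs before the existing in-place temperature scaling; that ordering is safe because a logit forced to -inf stays -inf under any positive temperature and receives probability exactly 0 after softmax. A two-line check:

import torch

logits = torch.tensor([2.0, float("-inf"), 0.5])
print(torch.softmax(logits / 0.7, dim=-1))  # the banned entry is exactly 0.0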
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index 5379aae72..b4bc8d4d4 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -27,17 +27,30 @@ class RequestState:
         self.index_to_req_id: dict[int, str] = {}
         self.free_indices = list(range(max_num_reqs))
 
-        self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
         # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
         # depending on the configured max_num_reqs and max_model_len.
         # To save GPU memory, we use UVA instead of GPU for this tensor.
-        self.prefill_token_ids = StagedWriteTensor(
+        self.all_token_ids = StagedWriteTensor(
             (self.max_num_reqs, self.max_model_len),
             dtype=torch.int32,
             device=device,
             uva_instead_of_gpu=True,
         )
+        # NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
+        # - prompt_len: Number of tokens in the user-provided prompt.
+        # - prefill_len: Number of tokens passed into the model runner.
+        #   This can include the prompt and additional partial output tokens,
+        #   so prefill_len >= prompt_len.
+        # Usually, prefill_len equals prompt_len, but in cases such as resumption after
+        # preemption, prefill_len may be greater. Differentiating between these values
+        # is crucial, as certain features such as prompt logprobs or frequency penalties
+        # must treat prompt and output tokens separately.
+        self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
         self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
+        # total_len = prompt_len + output_len. It grows as the request progresses.
+        self.total_len = StagedWriteTensor(
+            self.max_num_reqs, dtype=torch.int32, device=device
+        )
 
         # Number of computed tokens.
         self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -72,7 +85,7 @@ class RequestState:
         self,
         req_id: str,
         prompt_len: int,
-        prefill_token_ids: list[int],
+        all_token_ids: list[int],
         num_computed_tokens: int,
     ) -> None:
         assert len(self.free_indices) > 0, "No free indices"
@@ -80,19 +93,22 @@ class RequestState:
         self.req_id_to_index[req_id] = req_idx
         self.index_to_req_id[req_idx] = req_id
 
-        self.prompt_len[req_idx] = prompt_len
-        prefill_len = len(prefill_token_ids)
+        self.prompt_len.np[req_idx] = prompt_len
+        prefill_len = len(all_token_ids)
         assert prefill_len >= prompt_len, (
             f"prefill_len {prefill_len} < prompt_len {prompt_len}"
        )
         self.prefill_len.np[req_idx] = prefill_len
-        self.prefill_token_ids.stage_write(req_idx, 0, prefill_token_ids)
+        self.total_len.stage_write_elem(req_idx, prefill_len)
+        self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
         self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
         self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
 
     def apply_staged_writes(self) -> None:
+        self.prompt_len.copy_to_uva()
         self.prefill_len.copy_to_uva()
-        self.prefill_token_ids.apply_write()
+        self.total_len.apply_write()
+        self.all_token_ids.apply_write()
         self.num_computed_tokens.apply_write()
 
     def remove_request(self, req_id: str) -> None:
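A small worked example of the three lengths for a request resumed after preemption, following the NOTE above (the numbers are made up):

prompt_len = 100         # tokens in the user-provided prompt
prefill_len = 120        # prompt + 20 previously generated tokens fed back in
total_len = prefill_len  # set by add_request(); 120 at this point
# After the next decode step samples one accepted token, post_update() advances it:
total_len += 1           # 121 == prompt_len + output_len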