[Model Runner V2] Refactor prefill token preparation (#29712)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -78,7 +78,7 @@ class CudaGraphManager:
|
|||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
) -> None:
|
) -> None:
|
||||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||||
input_ids = input_buffers.input_ids.gpu[:num_tokens]
|
input_ids = input_buffers.input_ids[:num_tokens]
|
||||||
positions = input_buffers.positions[:num_tokens]
|
positions = input_buffers.positions[:num_tokens]
|
||||||
attn_metadata = prepare_inputs_to_capture(
|
attn_metadata = prepare_inputs_to_capture(
|
||||||
num_reqs,
|
num_reqs,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import numba
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -30,15 +29,12 @@ class InputBuffers:
|
|||||||
self.pin_memory = pin_memory
|
self.pin_memory = pin_memory
|
||||||
|
|
||||||
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||||
self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32)
|
self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
|
||||||
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
|
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
|
||||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||||
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||||
|
|
||||||
# Spec decoding.
|
|
||||||
self.next_prefill_tokens = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
|
||||||
|
|
||||||
# Structured outputs.
|
# Structured outputs.
|
||||||
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||||
self.grammar_bitmask = self._make_buffer(
|
self.grammar_bitmask = self._make_buffer(
|
||||||
@@ -120,7 +116,7 @@ class InputBatch:
|
|||||||
input_buffers.seq_lens[num_reqs:] = 0
|
input_buffers.seq_lens[num_reqs:] = 0
|
||||||
seq_lens = input_buffers.seq_lens[:num_reqs]
|
seq_lens = input_buffers.seq_lens[:num_reqs]
|
||||||
|
|
||||||
input_ids = input_buffers.input_ids.copy_to_gpu(num_tokens)
|
input_ids = input_buffers.input_ids[:num_tokens]
|
||||||
positions = input_buffers.positions[:num_tokens]
|
positions = input_buffers.positions[:num_tokens]
|
||||||
# attn_metadata = defaultdict(lambda: None)
|
# attn_metadata = defaultdict(lambda: None)
|
||||||
logits_indices = query_start_loc[1:] - 1
|
logits_indices = query_start_loc[1:] - 1
|
||||||
@@ -146,41 +142,63 @@ class InputBatch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@numba.njit(cache=True)
|
@triton.jit
|
||||||
def _prepare_prefill_inputs(
|
def _prepare_prefill_inputs_kernel(
|
||||||
idx_mapping: np.ndarray, # [B]
|
input_ids_ptr,
|
||||||
query_lens: np.ndarray, # [B]
|
next_prefill_tokens_ptr,
|
||||||
query_start_loc: np.ndarray, # [B + 1]
|
idx_mapping_ptr,
|
||||||
prefill_token_ids: np.ndarray, # [N, max_model_len]
|
query_start_loc_ptr,
|
||||||
num_computed_prefill_tokens: np.ndarray, # [N]
|
prefill_token_ids_ptr,
|
||||||
input_ids: np.ndarray, # [num_input_tokens]
|
prefill_token_ids_stride,
|
||||||
) -> None:
|
prefill_lens_ptr,
|
||||||
num_reqs = idx_mapping.shape[0]
|
num_computed_tokens_ptr,
|
||||||
query_starts = query_start_loc[:num_reqs]
|
BLOCK_SIZE: tl.constexpr,
|
||||||
query_ends = query_start_loc[1 : num_reqs + 1]
|
):
|
||||||
starts = num_computed_prefill_tokens[idx_mapping]
|
batch_idx = tl.program_id(0)
|
||||||
ends = starts + query_lens
|
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||||
for i in range(num_reqs):
|
prefill_len = tl.load(prefill_lens_ptr + req_state_idx)
|
||||||
input_ids[query_starts[i] : query_ends[i]] = prefill_token_ids[
|
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||||
idx_mapping[i], starts[i] : ends[i]
|
if num_computed >= prefill_len:
|
||||||
]
|
# Not prefill.
|
||||||
|
return
|
||||||
|
|
||||||
|
query_start = tl.load(query_start_loc_ptr + batch_idx)
|
||||||
|
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||||
|
query_len = query_end - query_start
|
||||||
|
|
||||||
|
prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
|
||||||
|
for i in range(0, query_len, BLOCK_SIZE):
|
||||||
|
block = i + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = block < query_len
|
||||||
|
tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
|
||||||
|
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
|
||||||
|
|
||||||
|
next_pos = num_computed + query_len
|
||||||
|
if next_pos < prefill_len:
|
||||||
|
next_token = tl.load(prefill_ptr + next_pos)
|
||||||
|
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
|
||||||
|
|
||||||
|
|
||||||
def prepare_prefill_inputs(
|
def prepare_prefill_inputs(
|
||||||
idx_mapping: np.ndarray,
|
input_ids: torch.Tensor,
|
||||||
num_scheduled_tokens: np.ndarray,
|
next_prefill_tokens: torch.Tensor,
|
||||||
query_start_loc: np.ndarray,
|
idx_mapping: torch.Tensor,
|
||||||
prefill_token_ids: np.ndarray,
|
query_start_loc: torch.Tensor,
|
||||||
num_computed_prefill_tokens: np.ndarray,
|
prefill_token_ids: torch.Tensor,
|
||||||
input_ids: np.ndarray,
|
prefill_len: torch.Tensor,
|
||||||
|
num_computed_tokens: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
_prepare_prefill_inputs(
|
num_reqs = idx_mapping.shape[0]
|
||||||
|
_prepare_prefill_inputs_kernel[(num_reqs,)](
|
||||||
|
input_ids,
|
||||||
|
next_prefill_tokens,
|
||||||
idx_mapping,
|
idx_mapping,
|
||||||
num_scheduled_tokens,
|
|
||||||
query_start_loc,
|
query_start_loc,
|
||||||
prefill_token_ids,
|
prefill_token_ids,
|
||||||
num_computed_prefill_tokens,
|
prefill_token_ids.stride(0),
|
||||||
input_ids,
|
prefill_len,
|
||||||
|
num_computed_tokens,
|
||||||
|
BLOCK_SIZE=1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -104,11 +104,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if self.use_async_scheduling:
|
if self.use_async_scheduling:
|
||||||
self.input_prep_event = torch.cuda.Event()
|
self.input_prep_event = torch.cuda.Event()
|
||||||
self.structured_outputs_event = torch.cuda.Event()
|
self.structured_outputs_event = torch.cuda.Event()
|
||||||
self.spec_decode_event = torch.cuda.Event()
|
|
||||||
else:
|
else:
|
||||||
self.input_prep_event = None
|
self.input_prep_event = None
|
||||||
self.structured_outputs_event = None
|
self.structured_outputs_event = None
|
||||||
self.spec_decode_event = None
|
|
||||||
|
|
||||||
if self.speculative_config is not None:
|
if self.speculative_config is not None:
|
||||||
self.do_spec_decode = True
|
self.do_spec_decode = True
|
||||||
@@ -412,9 +410,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
cu_num_new_blocks[i].append(x + len(block_ids))
|
cu_num_new_blocks[i].append(x + len(block_ids))
|
||||||
new_block_ids[i].extend(block_ids)
|
new_block_ids[i].extend(block_ids)
|
||||||
overwrite.append(True)
|
overwrite.append(True)
|
||||||
# Update the GPU tensors for request states.
|
|
||||||
if scheduler_output.scheduled_new_reqs:
|
|
||||||
self.req_states.prefill_len.copy_to_gpu()
|
|
||||||
|
|
||||||
# Add new blocks for the existing requests.
|
# Add new blocks for the existing requests.
|
||||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||||
@@ -507,16 +502,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
|
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
|
||||||
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
|
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
|
||||||
|
|
||||||
# Copy prefill tokens from CPU to GPU.
|
# Get prefill tokens.
|
||||||
prepare_prefill_inputs(
|
prepare_prefill_inputs(
|
||||||
idx_mapping_np,
|
self.input_buffers.input_ids,
|
||||||
num_scheduled_tokens,
|
self.req_states.next_prefill_tokens,
|
||||||
query_start_loc_np,
|
idx_mapping,
|
||||||
self.req_states.prefill_token_ids.np,
|
query_start_loc_gpu,
|
||||||
self.req_states.num_computed_prefill_tokens,
|
self.req_states.prefill_token_ids.gpu,
|
||||||
self.input_buffers.input_ids.np,
|
self.req_states.prefill_len.gpu,
|
||||||
|
self.req_states.num_computed_tokens,
|
||||||
)
|
)
|
||||||
self.input_buffers.input_ids.copy_to_gpu(num_tokens)
|
|
||||||
|
|
||||||
# Prepare positions and seq_lens.
|
# Prepare positions and seq_lens.
|
||||||
prepare_pos_seq_lens(
|
prepare_pos_seq_lens(
|
||||||
@@ -531,7 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# Some input token ids are directly read from the last sampled tokens
|
# Some input token ids are directly read from the last sampled tokens
|
||||||
# and draft tokens. Also, get the logits indices to sample tokens from.
|
# and draft tokens. Also, get the logits indices to sample tokens from.
|
||||||
logits_indices = combine_sampled_and_draft_tokens(
|
logits_indices = combine_sampled_and_draft_tokens(
|
||||||
self.input_buffers.input_ids.gpu,
|
self.input_buffers.input_ids,
|
||||||
idx_mapping,
|
idx_mapping,
|
||||||
self.req_states.last_sampled_tokens,
|
self.req_states.last_sampled_tokens,
|
||||||
query_start_loc_gpu,
|
query_start_loc_gpu,
|
||||||
@@ -572,7 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
kv_cache_config=self.kv_cache_config,
|
kv_cache_config=self.kv_cache_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
|
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
|
||||||
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
positions = self.input_buffers.positions[:num_tokens_after_padding]
|
||||||
return InputBatch(
|
return InputBatch(
|
||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
@@ -782,20 +777,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_sampled: torch.Tensor,
|
num_sampled: torch.Tensor,
|
||||||
num_rejected: torch.Tensor,
|
num_rejected: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_reqs = input_batch.num_reqs
|
|
||||||
idx_mapping_np = input_batch.idx_mapping_np
|
|
||||||
with async_barrier(self.spec_decode_event):
|
|
||||||
self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
|
|
||||||
self.req_states.prefill_token_ids.np[
|
|
||||||
idx_mapping_np,
|
|
||||||
self.req_states.num_computed_prefill_tokens[idx_mapping_np],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
|
|
||||||
num_reqs
|
|
||||||
)
|
|
||||||
|
|
||||||
assert self.speculator is not None
|
assert self.speculator is not None
|
||||||
|
last_sampled_tokens = self.req_states.last_sampled_tokens[
|
||||||
|
input_batch.idx_mapping
|
||||||
|
]
|
||||||
|
next_prefill_tokens = self.req_states.next_prefill_tokens[
|
||||||
|
input_batch.idx_mapping
|
||||||
|
]
|
||||||
draft_tokens = self.speculator.propose(
|
draft_tokens = self.speculator.propose(
|
||||||
input_batch,
|
input_batch,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
@@ -803,7 +791,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
aux_hidden_states,
|
aux_hidden_states,
|
||||||
num_sampled,
|
num_sampled,
|
||||||
num_rejected,
|
num_rejected,
|
||||||
self.req_states.last_sampled_tokens,
|
last_sampled_tokens,
|
||||||
next_prefill_tokens,
|
next_prefill_tokens,
|
||||||
)
|
)
|
||||||
return draft_tokens
|
return draft_tokens
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class EagleSpeculator:
|
|||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
):
|
):
|
||||||
ret_hidden_states = self.model(
|
ret_hidden_states = self.model(
|
||||||
input_ids=self.input_buffers.input_ids.gpu[:num_tokens],
|
input_ids=self.input_buffers.input_ids[:num_tokens],
|
||||||
positions=self.input_buffers.positions[:num_tokens],
|
positions=self.input_buffers.positions[:num_tokens],
|
||||||
hidden_states=self.hidden_states[:num_tokens],
|
hidden_states=self.hidden_states[:num_tokens],
|
||||||
)
|
)
|
||||||
@@ -194,7 +194,7 @@ class EagleSpeculator:
|
|||||||
num_sampled: torch.Tensor,
|
num_sampled: torch.Tensor,
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
num_rejected: torch.Tensor,
|
num_rejected: torch.Tensor,
|
||||||
# [max_num_reqs, 1]
|
# [num_reqs]
|
||||||
last_sampled: torch.Tensor,
|
last_sampled: torch.Tensor,
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
next_prefill_tokens: torch.Tensor,
|
next_prefill_tokens: torch.Tensor,
|
||||||
@@ -316,7 +316,6 @@ def _prepare_eagle_inputs_kernel(
|
|||||||
eagle_positions_ptr,
|
eagle_positions_ptr,
|
||||||
target_input_ids_ptr,
|
target_input_ids_ptr,
|
||||||
target_positions_ptr,
|
target_positions_ptr,
|
||||||
idx_mapping_ptr,
|
|
||||||
last_sampled_ptr,
|
last_sampled_ptr,
|
||||||
next_prefill_tokens_ptr,
|
next_prefill_tokens_ptr,
|
||||||
num_sampled_ptr,
|
num_sampled_ptr,
|
||||||
@@ -335,8 +334,7 @@ def _prepare_eagle_inputs_kernel(
|
|||||||
|
|
||||||
num_sampled = tl.load(num_sampled_ptr + batch_idx)
|
num_sampled = tl.load(num_sampled_ptr + batch_idx)
|
||||||
if num_sampled > 0:
|
if num_sampled > 0:
|
||||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
next_token = tl.load(last_sampled_ptr + batch_idx).to(tl.int32)
|
||||||
next_token = tl.load(last_sampled_ptr + req_state_idx).to(tl.int32)
|
|
||||||
else:
|
else:
|
||||||
# Chunked prefilling.
|
# Chunked prefilling.
|
||||||
# Get the next prefill token.
|
# Get the next prefill token.
|
||||||
@@ -368,9 +366,9 @@ def prepare_eagle_inputs(
|
|||||||
num_sampled: torch.Tensor,
|
num_sampled: torch.Tensor,
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
num_rejected: torch.Tensor,
|
num_rejected: torch.Tensor,
|
||||||
# [max_num_reqs, 1]
|
# [num_reqs]
|
||||||
last_sampled: torch.Tensor,
|
last_sampled: torch.Tensor,
|
||||||
# [max_num_reqs]
|
# [num_reqs]
|
||||||
next_prefill_tokens: torch.Tensor,
|
next_prefill_tokens: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_reqs = input_batch.num_reqs
|
num_reqs = input_batch.num_reqs
|
||||||
@@ -381,11 +379,10 @@ def prepare_eagle_inputs(
|
|||||||
)
|
)
|
||||||
_prepare_eagle_inputs_kernel[(num_reqs,)](
|
_prepare_eagle_inputs_kernel[(num_reqs,)](
|
||||||
last_token_indices,
|
last_token_indices,
|
||||||
input_buffers.input_ids.gpu,
|
input_buffers.input_ids,
|
||||||
input_buffers.positions,
|
input_buffers.positions,
|
||||||
input_batch.input_ids,
|
input_batch.input_ids,
|
||||||
input_batch.positions,
|
input_batch.positions,
|
||||||
input_batch.idx_mapping,
|
|
||||||
last_sampled,
|
last_sampled,
|
||||||
next_prefill_tokens,
|
next_prefill_tokens,
|
||||||
num_sampled,
|
num_sampled,
|
||||||
@@ -485,7 +482,7 @@ def prepare_eagle_decode(
|
|||||||
last_token_indices,
|
last_token_indices,
|
||||||
target_seq_lens,
|
target_seq_lens,
|
||||||
num_rejected,
|
num_rejected,
|
||||||
input_buffers.input_ids.gpu,
|
input_buffers.input_ids,
|
||||||
input_buffers.positions,
|
input_buffers.positions,
|
||||||
input_hidden_states,
|
input_hidden_states,
|
||||||
input_hidden_states.stride(0),
|
input_hidden_states.stride(0),
|
||||||
@@ -553,7 +550,7 @@ def update_eagle_inputs(
|
|||||||
):
|
):
|
||||||
num_reqs, hidden_size = output_hidden_states.shape
|
num_reqs, hidden_size = output_hidden_states.shape
|
||||||
_update_eagle_inputs_kernel[(num_reqs,)](
|
_update_eagle_inputs_kernel[(num_reqs,)](
|
||||||
input_buffers.input_ids.gpu,
|
input_buffers.input_ids,
|
||||||
input_buffers.positions,
|
input_buffers.positions,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
hidden_states.stride(0),
|
hidden_states.stride(0),
|
||||||
|
|||||||
@@ -117,8 +117,7 @@ class RequestState:
|
|||||||
self.prefill_token_ids = UvaBuffer(
|
self.prefill_token_ids = UvaBuffer(
|
||||||
self.max_num_reqs, self.max_model_len, dtype=torch.int32
|
self.max_num_reqs, self.max_model_len, dtype=torch.int32
|
||||||
)
|
)
|
||||||
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
|
self.prefill_len = UvaBuffer(self.max_num_reqs, dtype=torch.int32)
|
||||||
|
|
||||||
# Number of computed tokens.
|
# Number of computed tokens.
|
||||||
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||||
self.num_computed_tokens = torch.zeros(
|
self.num_computed_tokens = torch.zeros(
|
||||||
@@ -140,6 +139,9 @@ class RequestState:
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
self.next_prefill_tokens = torch.zeros(
|
||||||
|
self.max_num_reqs, dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
|
|
||||||
# LoRA.
|
# LoRA.
|
||||||
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
|
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||||
@@ -380,13 +382,13 @@ def _expand_sampling_metadata_kernel(
|
|||||||
expanded_top_p_ptr,
|
expanded_top_p_ptr,
|
||||||
top_k_ptr,
|
top_k_ptr,
|
||||||
expanded_top_k_ptr,
|
expanded_top_k_ptr,
|
||||||
seeds_ptr,
|
|
||||||
rep_penalty_ptr,
|
rep_penalty_ptr,
|
||||||
expanded_rep_penalty_ptr,
|
expanded_rep_penalty_ptr,
|
||||||
freq_penalty_ptr,
|
freq_penalty_ptr,
|
||||||
expanded_freq_penalty_ptr,
|
expanded_freq_penalty_ptr,
|
||||||
pres_penalty_ptr,
|
pres_penalty_ptr,
|
||||||
expanded_pres_penalty_ptr,
|
expanded_pres_penalty_ptr,
|
||||||
|
seeds_ptr,
|
||||||
expanded_seeds_ptr,
|
expanded_seeds_ptr,
|
||||||
cu_num_logits_ptr,
|
cu_num_logits_ptr,
|
||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
|||||||
Reference in New Issue
Block a user