[Model Runner V2] Some code simplification (#36929)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -57,10 +57,7 @@ SpeculativeMethod = Literal[
|
||||
EagleModelTypes,
|
||||
NgramGPUTypes,
|
||||
]
|
||||
RejectionSampleMethod = Literal[
|
||||
"strict",
|
||||
"probabilistic",
|
||||
]
|
||||
RejectionSampleMethod = Literal["strict", "probabilistic"]
|
||||
|
||||
|
||||
@config
|
||||
|
||||
@@ -81,7 +81,7 @@ def _gumbel_sample_kernel(
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
|
||||
if (temp != 0.0) and APPLY_TEMPERATURE:
|
||||
if temp != 0.0 and APPLY_TEMPERATURE:
|
||||
# Apply temperature.
|
||||
# NOTE(woosuk): Match the behavior of _temperature_kernel.
|
||||
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
|
||||
@@ -127,18 +127,8 @@ def gumbel_sample(
|
||||
num_tokens, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = torch.empty(
|
||||
num_tokens,
|
||||
num_blocks,
|
||||
dtype=torch.int64,
|
||||
device=logits.device,
|
||||
)
|
||||
local_max = torch.empty(
|
||||
num_tokens,
|
||||
num_blocks,
|
||||
dtype=torch.float32,
|
||||
device=logits.device,
|
||||
)
|
||||
local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
|
||||
local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32)
|
||||
_gumbel_sample_kernel[(num_tokens, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
|
||||
@@ -53,17 +53,8 @@ def strict_rejection_sample(
|
||||
num_speculative_steps,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
sampled = torch.empty(
|
||||
num_reqs,
|
||||
num_speculative_steps + 1,
|
||||
dtype=target_sampled.dtype,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
num_sampled = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
|
||||
num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)
|
||||
_strict_rejection_sample_kernel[(num_reqs,)](
|
||||
sampled,
|
||||
sampled.stride(0),
|
||||
@@ -216,12 +207,11 @@ def probabilistic_rejection_sample(
|
||||
pos: torch.Tensor,
|
||||
# [num_reqs]
|
||||
idx_mapping: torch.Tensor,
|
||||
temperature,
|
||||
seeds,
|
||||
num_speculative_steps,
|
||||
temperature: torch.Tensor,
|
||||
seed: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
device = target_logits.device
|
||||
vocab_size = target_logits.shape[-1]
|
||||
|
||||
# Compute target and draft probs.
|
||||
@@ -230,18 +220,11 @@ def probabilistic_rejection_sample(
|
||||
|
||||
# Rejection sample.
|
||||
# [num_reqs, num_speculative_steps + 1]
|
||||
sampled = torch.empty(
|
||||
num_reqs,
|
||||
num_speculative_steps + 1,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
sampled = draft_sampled.new_empty(
|
||||
num_reqs, num_speculative_steps + 1, dtype=torch.int64
|
||||
)
|
||||
# [num_reqs]
|
||||
rejected_steps = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
rejected_steps = sampled.new_empty(num_reqs)
|
||||
_probabilistic_rejection_sample_kernel[(num_reqs,)](
|
||||
sampled,
|
||||
sampled.stride(0),
|
||||
@@ -255,25 +238,16 @@ def probabilistic_rejection_sample(
|
||||
cu_num_logits,
|
||||
pos,
|
||||
idx_mapping,
|
||||
seeds,
|
||||
seed,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
# Compute the logits and positions to resample the rejected/bonus
|
||||
# tokens from.
|
||||
# [num_reqs, vocab_size]
|
||||
residual_logits = torch.empty(
|
||||
num_reqs,
|
||||
vocab_size,
|
||||
dtype=target_logits.dtype,
|
||||
device=device,
|
||||
)
|
||||
residual_logits = target_logits.new_empty(num_reqs, vocab_size)
|
||||
# [num_reqs]
|
||||
residual_pos = torch.empty(
|
||||
num_reqs,
|
||||
dtype=pos.dtype,
|
||||
device=device,
|
||||
)
|
||||
residual_pos = pos.new_empty(num_reqs)
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
_compute_residual_logits_kernel[(num_reqs, num_blocks)](
|
||||
@@ -299,7 +273,7 @@ def probabilistic_rejection_sample(
|
||||
residual_logits,
|
||||
idx_mapping,
|
||||
temperature,
|
||||
seeds,
|
||||
seed,
|
||||
residual_pos,
|
||||
apply_temperature=False,
|
||||
)
|
||||
@@ -331,10 +305,7 @@ class RejectionSampler:
|
||||
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
|
||||
|
||||
if self.use_strict_rejection_sampling:
|
||||
sampler_output = self.sampler(
|
||||
logits,
|
||||
input_batch,
|
||||
)
|
||||
sampler_output = self.sampler(logits, input_batch)
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
sampled, num_sampled = strict_rejection_sample(
|
||||
sampler_output.sampled_token_ids.view(-1),
|
||||
|
||||
Reference in New Issue
Block a user