[Model Runner V2] Some code simplification (#36929)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -57,10 +57,7 @@ SpeculativeMethod = Literal[
|
|||||||
EagleModelTypes,
|
EagleModelTypes,
|
||||||
NgramGPUTypes,
|
NgramGPUTypes,
|
||||||
]
|
]
|
||||||
RejectionSampleMethod = Literal[
|
RejectionSampleMethod = Literal["strict", "probabilistic"]
|
||||||
"strict",
|
|
||||||
"probabilistic",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ def _gumbel_sample_kernel(
|
|||||||
logits = logits.to(tl.float32)
|
logits = logits.to(tl.float32)
|
||||||
|
|
||||||
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
|
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
|
||||||
if (temp != 0.0) and APPLY_TEMPERATURE:
|
if temp != 0.0 and APPLY_TEMPERATURE:
|
||||||
# Apply temperature.
|
# Apply temperature.
|
||||||
# NOTE(woosuk): Match the behavior of _temperature_kernel.
|
# NOTE(woosuk): Match the behavior of _temperature_kernel.
|
||||||
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
|
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
|
||||||
@@ -127,18 +127,8 @@ def gumbel_sample(
|
|||||||
num_tokens, vocab_size = logits.shape
|
num_tokens, vocab_size = logits.shape
|
||||||
BLOCK_SIZE = 1024
|
BLOCK_SIZE = 1024
|
||||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||||
local_argmax = torch.empty(
|
local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
|
||||||
num_tokens,
|
local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32)
|
||||||
num_blocks,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
local_max = torch.empty(
|
|
||||||
num_tokens,
|
|
||||||
num_blocks,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
_gumbel_sample_kernel[(num_tokens, num_blocks)](
|
_gumbel_sample_kernel[(num_tokens, num_blocks)](
|
||||||
local_argmax,
|
local_argmax,
|
||||||
local_argmax.stride(0),
|
local_argmax.stride(0),
|
||||||
|
|||||||
@@ -53,17 +53,8 @@ def strict_rejection_sample(
|
|||||||
num_speculative_steps,
|
num_speculative_steps,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
num_reqs = cu_num_logits.shape[0] - 1
|
num_reqs = cu_num_logits.shape[0] - 1
|
||||||
sampled = torch.empty(
|
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
|
||||||
num_reqs,
|
num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)
|
||||||
num_speculative_steps + 1,
|
|
||||||
dtype=target_sampled.dtype,
|
|
||||||
device=target_sampled.device,
|
|
||||||
)
|
|
||||||
num_sampled = torch.empty(
|
|
||||||
num_reqs,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=target_sampled.device,
|
|
||||||
)
|
|
||||||
_strict_rejection_sample_kernel[(num_reqs,)](
|
_strict_rejection_sample_kernel[(num_reqs,)](
|
||||||
sampled,
|
sampled,
|
||||||
sampled.stride(0),
|
sampled.stride(0),
|
||||||
@@ -216,12 +207,11 @@ def probabilistic_rejection_sample(
|
|||||||
pos: torch.Tensor,
|
pos: torch.Tensor,
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
idx_mapping: torch.Tensor,
|
idx_mapping: torch.Tensor,
|
||||||
temperature,
|
temperature: torch.Tensor,
|
||||||
seeds,
|
seed: torch.Tensor,
|
||||||
num_speculative_steps,
|
num_speculative_steps: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
num_reqs = cu_num_logits.shape[0] - 1
|
num_reqs = cu_num_logits.shape[0] - 1
|
||||||
device = target_logits.device
|
|
||||||
vocab_size = target_logits.shape[-1]
|
vocab_size = target_logits.shape[-1]
|
||||||
|
|
||||||
# Compute target and draft probs.
|
# Compute target and draft probs.
|
||||||
@@ -230,18 +220,11 @@ def probabilistic_rejection_sample(
|
|||||||
|
|
||||||
# Rejection sample.
|
# Rejection sample.
|
||||||
# [num_reqs, num_speculative_steps + 1]
|
# [num_reqs, num_speculative_steps + 1]
|
||||||
sampled = torch.empty(
|
sampled = draft_sampled.new_empty(
|
||||||
num_reqs,
|
num_reqs, num_speculative_steps + 1, dtype=torch.int64
|
||||||
num_speculative_steps + 1,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=device,
|
|
||||||
)
|
)
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
rejected_steps = torch.empty(
|
rejected_steps = sampled.new_empty(num_reqs)
|
||||||
num_reqs,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
_probabilistic_rejection_sample_kernel[(num_reqs,)](
|
_probabilistic_rejection_sample_kernel[(num_reqs,)](
|
||||||
sampled,
|
sampled,
|
||||||
sampled.stride(0),
|
sampled.stride(0),
|
||||||
@@ -255,25 +238,16 @@ def probabilistic_rejection_sample(
|
|||||||
cu_num_logits,
|
cu_num_logits,
|
||||||
pos,
|
pos,
|
||||||
idx_mapping,
|
idx_mapping,
|
||||||
seeds,
|
seed,
|
||||||
num_warps=1,
|
num_warps=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute the logits and positions to resample the rejected/bonus
|
# Compute the logits and positions to resample the rejected/bonus
|
||||||
# tokens from.
|
# tokens from.
|
||||||
# [num_reqs, vocab_size]
|
# [num_reqs, vocab_size]
|
||||||
residual_logits = torch.empty(
|
residual_logits = target_logits.new_empty(num_reqs, vocab_size)
|
||||||
num_reqs,
|
|
||||||
vocab_size,
|
|
||||||
dtype=target_logits.dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
residual_pos = torch.empty(
|
residual_pos = pos.new_empty(num_reqs)
|
||||||
num_reqs,
|
|
||||||
dtype=pos.dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
BLOCK_SIZE = 1024
|
BLOCK_SIZE = 1024
|
||||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||||
_compute_residual_logits_kernel[(num_reqs, num_blocks)](
|
_compute_residual_logits_kernel[(num_reqs, num_blocks)](
|
||||||
@@ -299,7 +273,7 @@ def probabilistic_rejection_sample(
|
|||||||
residual_logits,
|
residual_logits,
|
||||||
idx_mapping,
|
idx_mapping,
|
||||||
temperature,
|
temperature,
|
||||||
seeds,
|
seed,
|
||||||
residual_pos,
|
residual_pos,
|
||||||
apply_temperature=False,
|
apply_temperature=False,
|
||||||
)
|
)
|
||||||
@@ -331,10 +305,7 @@ class RejectionSampler:
|
|||||||
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
|
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
|
||||||
|
|
||||||
if self.use_strict_rejection_sampling:
|
if self.use_strict_rejection_sampling:
|
||||||
sampler_output = self.sampler(
|
sampler_output = self.sampler(logits, input_batch)
|
||||||
logits,
|
|
||||||
input_batch,
|
|
||||||
)
|
|
||||||
logprobs_tensors = sampler_output.logprobs_tensors
|
logprobs_tensors = sampler_output.logprobs_tensors
|
||||||
sampled, num_sampled = strict_rejection_sample(
|
sampled, num_sampled = strict_rejection_sample(
|
||||||
sampler_output.sampled_token_ids.view(-1),
|
sampler_output.sampled_token_ids.view(-1),
|
||||||
|
|||||||
Reference in New Issue
Block a user