[Model Runner V2] Add sample/ directory and reorganize files (#29719)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -47,13 +47,18 @@ from vllm.v1.worker.gpu.input_batch import (
|
|||||||
prepare_pos_seq_lens,
|
prepare_pos_seq_lens,
|
||||||
prepare_prefill_inputs,
|
prepare_prefill_inputs,
|
||||||
)
|
)
|
||||||
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
|
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
|
||||||
|
from vllm.v1.worker.gpu.sample.metadata import (
|
||||||
|
SamplingMetadata,
|
||||||
|
expand_sampling_metadata,
|
||||||
|
)
|
||||||
|
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||||
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
||||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
|
from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
|
||||||
get_num_rejected,
|
get_num_rejected,
|
||||||
rejection_sample,
|
rejection_sample,
|
||||||
)
|
)
|
||||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
from vllm.v1.worker.gpu.states import RequestState
|
||||||
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
|
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
|
||||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||||
@@ -890,8 +895,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
|
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
|
||||||
)
|
)
|
||||||
if input_batch.num_draft_tokens > 0:
|
if input_batch.num_draft_tokens > 0:
|
||||||
sampling_metadata = self.req_states.expand_sampling_metadata(
|
sampling_metadata = expand_sampling_metadata(
|
||||||
sampling_metadata, input_batch.cu_num_logits
|
sampling_metadata,
|
||||||
|
input_batch.cu_num_logits,
|
||||||
|
max_expand_len=self.num_speculative_steps + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
|
|||||||
0
vllm/v1/worker/gpu/sample/__init__.py
Normal file
0
vllm/v1/worker/gpu/sample/__init__.py
Normal file
100
vllm/v1/worker/gpu/sample/gumbel.py
Normal file
100
vllm/v1/worker/gpu/sample/gumbel.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _gumbel_sample_kernel(
|
||||||
|
local_argmax_ptr,
|
||||||
|
local_argmax_stride,
|
||||||
|
local_max_ptr,
|
||||||
|
local_max_stride,
|
||||||
|
logits_ptr,
|
||||||
|
logits_stride,
|
||||||
|
seeds_ptr,
|
||||||
|
pos_ptr,
|
||||||
|
temp_ptr,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
APPLY_TEMPERATURE: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
block_idx = tl.program_id(1)
|
||||||
|
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = block < vocab_size
|
||||||
|
logits = tl.load(
|
||||||
|
logits_ptr + req_idx * logits_stride + block,
|
||||||
|
mask=mask,
|
||||||
|
other=float("-inf"),
|
||||||
|
)
|
||||||
|
logits = logits.to(tl.float32)
|
||||||
|
|
||||||
|
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
|
||||||
|
if temp != 0.0:
|
||||||
|
# Calculate the seed for gumbel noise.
|
||||||
|
seed = tl.load(seeds_ptr + req_idx)
|
||||||
|
pos = tl.load(pos_ptr + req_idx)
|
||||||
|
gumbel_seed = tl.randint(seed, pos)
|
||||||
|
|
||||||
|
# Generate gumbel noise.
|
||||||
|
r = tl.rand(gumbel_seed, block).to(tl.float64)
|
||||||
|
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
||||||
|
gumbel_noise = gumbel_noise.to(tl.float32)
|
||||||
|
|
||||||
|
# Apply temperature.
|
||||||
|
if APPLY_TEMPERATURE:
|
||||||
|
# NOTE(woosuk): Use div_rn to match the behavior of torch.
|
||||||
|
logits = tl.div_rn(logits, temp)
|
||||||
|
|
||||||
|
# Apply gumbel noise.
|
||||||
|
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||||
|
|
||||||
|
idx = tl.argmax(logits, axis=0)
|
||||||
|
token_id = block_idx * BLOCK_SIZE + idx
|
||||||
|
value = tl.max(logits, axis=0)
|
||||||
|
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
|
||||||
|
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
|
||||||
|
|
||||||
|
|
||||||
|
def gumbel_sample(
|
||||||
|
logits: torch.Tensor, # [num_reqs, vocab_size]
|
||||||
|
temperature: torch.Tensor, # [num_reqs]
|
||||||
|
seed: torch.Tensor, # [num_reqs]
|
||||||
|
pos: torch.Tensor, # [num_reqs]
|
||||||
|
apply_temperature: bool,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
num_reqs, vocab_size = logits.shape
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||||
|
local_argmax = torch.empty(
|
||||||
|
num_reqs,
|
||||||
|
num_blocks,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
local_max = torch.empty(
|
||||||
|
num_reqs,
|
||||||
|
num_blocks,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
||||||
|
local_argmax,
|
||||||
|
local_argmax.stride(0),
|
||||||
|
local_max,
|
||||||
|
local_max.stride(0),
|
||||||
|
logits,
|
||||||
|
logits.stride(0),
|
||||||
|
seed,
|
||||||
|
pos,
|
||||||
|
temperature,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
APPLY_TEMPERATURE=apply_temperature,
|
||||||
|
)
|
||||||
|
# NOTE(woosuk): Use int64 for later indexing.
|
||||||
|
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
||||||
|
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
|
||||||
|
return sampled
|
||||||
167
vllm/v1/worker/gpu/sample/logprob.py
Normal file
167
vllm/v1/worker/gpu/sample/logprob.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _topk_log_softmax_kernel(
|
||||||
|
output_ptr,
|
||||||
|
logits_ptr,
|
||||||
|
logits_stride,
|
||||||
|
topk_ids_ptr,
|
||||||
|
topk,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
PADDED_TOPK: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
row_ptr = logits_ptr + req_idx * logits_stride
|
||||||
|
|
||||||
|
max_val = float("-inf")
|
||||||
|
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||||
|
block = i + tl.arange(0, BLOCK_SIZE)
|
||||||
|
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||||
|
max_val = tl.max(tl.maximum(logits, max_val))
|
||||||
|
max_val = max_val.to(tl.float32) # type: ignore
|
||||||
|
|
||||||
|
se = 0.0
|
||||||
|
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||||
|
block = i + tl.arange(0, BLOCK_SIZE)
|
||||||
|
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
||||||
|
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
|
||||||
|
logits = logits.to(tl.float32)
|
||||||
|
e = tl.exp(logits - max_val)
|
||||||
|
e = tl.where(block < vocab_size, e, 0.0)
|
||||||
|
se += tl.sum(e)
|
||||||
|
lse = tl.log(se)
|
||||||
|
|
||||||
|
k_offset = tl.arange(0, PADDED_TOPK)
|
||||||
|
k_mask = k_offset < topk
|
||||||
|
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
|
||||||
|
|
||||||
|
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
|
||||||
|
logits = logits.to(tl.float32)
|
||||||
|
o = logits - max_val - lse
|
||||||
|
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _ranks_kernel(
|
||||||
|
output_ptr,
|
||||||
|
logits_ptr,
|
||||||
|
logits_stride,
|
||||||
|
token_ids_ptr,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
row_ptr = logits_ptr + req_idx * logits_stride
|
||||||
|
|
||||||
|
token_id = tl.load(token_ids_ptr + req_idx)
|
||||||
|
x = tl.load(row_ptr + token_id)
|
||||||
|
|
||||||
|
n = 0
|
||||||
|
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||||
|
block = i + tl.arange(0, BLOCK_SIZE)
|
||||||
|
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
||||||
|
n += tl.sum((logits > x).to(tl.int32))
|
||||||
|
tl.store(output_ptr + req_idx, n)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_token_logprobs(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
token_ids: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size = logits.shape[0]
|
||||||
|
vocab_size = logits.shape[1]
|
||||||
|
token_ids = token_ids.to(torch.int64)
|
||||||
|
num_logprobs = token_ids.shape[1]
|
||||||
|
logprobs = torch.empty(
|
||||||
|
batch_size,
|
||||||
|
num_logprobs,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
_topk_log_softmax_kernel[(batch_size,)](
|
||||||
|
logprobs,
|
||||||
|
logits,
|
||||||
|
logits.stride(0),
|
||||||
|
token_ids,
|
||||||
|
num_logprobs,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE=1024, # type: ignore
|
||||||
|
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
|
||||||
|
)
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def compute_topk_logprobs(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
num_logprobs: int,
|
||||||
|
sampled_token_ids: torch.Tensor,
|
||||||
|
) -> LogprobsTensors:
|
||||||
|
assert num_logprobs >= 0
|
||||||
|
batch_size, vocab_size = logits.shape
|
||||||
|
if num_logprobs == 0:
|
||||||
|
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
||||||
|
logprob_token_ids = torch.cat(
|
||||||
|
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
||||||
|
# logprobs tensor. Instead, we only compute and return the logprobs of
|
||||||
|
# the topk + 1 tokens.
|
||||||
|
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
||||||
|
token_ranks = torch.empty(
|
||||||
|
batch_size,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=logits.device,
|
||||||
|
)
|
||||||
|
_ranks_kernel[(batch_size,)](
|
||||||
|
token_ranks,
|
||||||
|
logits,
|
||||||
|
logits.stride(0),
|
||||||
|
sampled_token_ids,
|
||||||
|
vocab_size,
|
||||||
|
BLOCK_SIZE=8192, # type: ignore
|
||||||
|
)
|
||||||
|
return LogprobsTensors(
|
||||||
|
logprob_token_ids=logprob_token_ids,
|
||||||
|
logprobs=logprobs,
|
||||||
|
selected_token_ranks=token_ranks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_prompt_logprobs(
|
||||||
|
prompt_token_ids: torch.Tensor,
|
||||||
|
prompt_hidden_states: torch.Tensor,
|
||||||
|
logits_fn: Callable[[torch.Tensor], torch.Tensor],
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Since materializing the full prompt logits can take too much memory,
|
||||||
|
# we compute it in chunks.
|
||||||
|
CHUNK_SIZE = 1024
|
||||||
|
logprobs = []
|
||||||
|
ranks = []
|
||||||
|
prompt_token_ids = prompt_token_ids.to(torch.int64)
|
||||||
|
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
|
||||||
|
end_idx = start_idx + CHUNK_SIZE
|
||||||
|
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
|
||||||
|
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
|
||||||
|
prompt_logprobs = compute_topk_logprobs(
|
||||||
|
prompt_logits,
|
||||||
|
0, # num_logprobs
|
||||||
|
prompt_token_ids[start_idx:end_idx],
|
||||||
|
)
|
||||||
|
logprobs.append(prompt_logprobs.logprobs)
|
||||||
|
ranks.append(prompt_logprobs.selected_token_ranks)
|
||||||
|
|
||||||
|
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
|
||||||
|
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
|
||||||
|
return logprobs, ranks
|
||||||
179
vllm/v1/worker/gpu/sample/metadata.py
Normal file
179
vllm/v1/worker/gpu/sample/metadata.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SamplingMetadata:
|
||||||
|
temperature: torch.Tensor
|
||||||
|
|
||||||
|
top_p: torch.Tensor | None
|
||||||
|
top_k: torch.Tensor | None
|
||||||
|
|
||||||
|
repetition_penalty: torch.Tensor
|
||||||
|
frequency_penalty: torch.Tensor
|
||||||
|
presence_penalty: torch.Tensor
|
||||||
|
|
||||||
|
seeds: torch.Tensor
|
||||||
|
pos: torch.Tensor
|
||||||
|
|
||||||
|
# None means no logprobs, 0 means sampled token logprobs only
|
||||||
|
max_num_logprobs: int | None
|
||||||
|
|
||||||
|
# For penalties
|
||||||
|
idx_mapping: torch.Tensor
|
||||||
|
prompt_bin_counts: torch.Tensor
|
||||||
|
output_bin_counts: torch.Tensor
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_dummy(
|
||||||
|
cls,
|
||||||
|
num_reqs: int,
|
||||||
|
device: torch.device,
|
||||||
|
) -> "SamplingMetadata":
|
||||||
|
assert num_reqs > 0
|
||||||
|
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||||
|
temperature[0] = 0.5
|
||||||
|
# TODO(woosuk): Use top-p and top-k for dummy sampler.
|
||||||
|
# Currently, they are disabled because of memory usage.
|
||||||
|
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
|
||||||
|
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
|
||||||
|
top_p = None
|
||||||
|
top_k = None
|
||||||
|
# NOTE(woosuk): We must set penalties to their default values to make sure
|
||||||
|
# the penalties kernel does not touch the placeholder bin_counts tensors.
|
||||||
|
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
|
||||||
|
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||||
|
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
||||||
|
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||||
|
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
||||||
|
max_num_logprobs = 20
|
||||||
|
|
||||||
|
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
||||||
|
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
|
||||||
|
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
|
||||||
|
# specialization and re-compilation at runtime.
|
||||||
|
prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||||
|
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
seeds=seeds,
|
||||||
|
pos=pos,
|
||||||
|
max_num_logprobs=max_num_logprobs,
|
||||||
|
idx_mapping=idx_mapping,
|
||||||
|
prompt_bin_counts=prompt_bin_counts,
|
||||||
|
output_bin_counts=output_bin_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
|
||||||
|
@triton.jit
|
||||||
|
def _expand_sampling_metadata_kernel(
|
||||||
|
temp_ptr,
|
||||||
|
expanded_temp_ptr,
|
||||||
|
top_p_ptr,
|
||||||
|
expanded_top_p_ptr,
|
||||||
|
top_k_ptr,
|
||||||
|
expanded_top_k_ptr,
|
||||||
|
rep_penalty_ptr,
|
||||||
|
expanded_rep_penalty_ptr,
|
||||||
|
freq_penalty_ptr,
|
||||||
|
expanded_freq_penalty_ptr,
|
||||||
|
pres_penalty_ptr,
|
||||||
|
expanded_pres_penalty_ptr,
|
||||||
|
seeds_ptr,
|
||||||
|
expanded_seeds_ptr,
|
||||||
|
cu_num_logits_ptr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||||
|
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||||
|
num_tokens = end_idx - start_idx
|
||||||
|
|
||||||
|
block = tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = block < num_tokens
|
||||||
|
|
||||||
|
temp = tl.load(temp_ptr + req_idx)
|
||||||
|
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
|
||||||
|
|
||||||
|
if top_p_ptr is not None:
|
||||||
|
top_p = tl.load(top_p_ptr + req_idx)
|
||||||
|
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
|
||||||
|
|
||||||
|
if top_k_ptr is not None:
|
||||||
|
top_k = tl.load(top_k_ptr + req_idx)
|
||||||
|
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
|
||||||
|
|
||||||
|
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
|
||||||
|
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
|
||||||
|
|
||||||
|
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
|
||||||
|
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
|
||||||
|
|
||||||
|
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
|
||||||
|
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
|
||||||
|
|
||||||
|
seed = tl.load(seeds_ptr + req_idx)
|
||||||
|
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def expand_sampling_metadata(
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
cu_num_logits: torch.Tensor,
|
||||||
|
max_expand_len: int,
|
||||||
|
) -> SamplingMetadata:
|
||||||
|
total_num_logits = sampling_metadata.pos.shape[0]
|
||||||
|
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
|
||||||
|
expanded_temp = create_empty(sampling_metadata.temperature)
|
||||||
|
expanded_top_p = create_empty(sampling_metadata.top_p)
|
||||||
|
expanded_top_k = create_empty(sampling_metadata.top_k)
|
||||||
|
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
|
||||||
|
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
|
||||||
|
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
|
||||||
|
expanded_seeds = create_empty(sampling_metadata.seeds)
|
||||||
|
|
||||||
|
num_reqs = cu_num_logits.shape[0] - 1
|
||||||
|
_expand_sampling_metadata_kernel[(num_reqs,)](
|
||||||
|
sampling_metadata.temperature,
|
||||||
|
expanded_temp,
|
||||||
|
sampling_metadata.top_p,
|
||||||
|
expanded_top_p,
|
||||||
|
sampling_metadata.top_k,
|
||||||
|
expanded_top_k,
|
||||||
|
sampling_metadata.repetition_penalty,
|
||||||
|
expanded_repetition_penalty,
|
||||||
|
sampling_metadata.frequency_penalty,
|
||||||
|
expanded_frequency_penalty,
|
||||||
|
sampling_metadata.presence_penalty,
|
||||||
|
expanded_presence_penalty,
|
||||||
|
sampling_metadata.seeds,
|
||||||
|
expanded_seeds,
|
||||||
|
cu_num_logits,
|
||||||
|
BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
|
||||||
|
)
|
||||||
|
return SamplingMetadata(
|
||||||
|
temperature=expanded_temp,
|
||||||
|
top_p=expanded_top_p,
|
||||||
|
top_k=expanded_top_k,
|
||||||
|
seeds=expanded_seeds,
|
||||||
|
repetition_penalty=expanded_repetition_penalty,
|
||||||
|
frequency_penalty=expanded_frequency_penalty,
|
||||||
|
presence_penalty=expanded_presence_penalty,
|
||||||
|
pos=sampling_metadata.pos,
|
||||||
|
max_num_logprobs=sampling_metadata.max_num_logprobs,
|
||||||
|
# TODO(woosuk): Support penalties with spec decoding.
|
||||||
|
idx_mapping=sampling_metadata.idx_mapping,
|
||||||
|
prompt_bin_counts=sampling_metadata.prompt_bin_counts,
|
||||||
|
output_bin_counts=sampling_metadata.output_bin_counts,
|
||||||
|
)
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -83,3 +83,49 @@ def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -
|
|||||||
vocab_size,
|
vocab_size,
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
|
||||||
|
def _bincount_kernel(
|
||||||
|
prefill_token_ids_ptr,
|
||||||
|
prefill_len,
|
||||||
|
prompt_len,
|
||||||
|
prompt_bin_counts_ptr,
|
||||||
|
output_bin_counts_ptr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
):
|
||||||
|
block_idx = tl.program_id(0)
|
||||||
|
if block_idx * BLOCK_SIZE >= prefill_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
if block_idx * BLOCK_SIZE < prompt_len:
|
||||||
|
mask = block < prompt_len
|
||||||
|
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||||
|
tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
||||||
|
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
|
||||||
|
mask = block < prefill_len
|
||||||
|
mask &= block >= prompt_len
|
||||||
|
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
||||||
|
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def bincount(
|
||||||
|
prefill_token_ids: torch.Tensor,
|
||||||
|
prefill_len: int,
|
||||||
|
prompt_len: int,
|
||||||
|
prompt_bin_counts: torch.Tensor,
|
||||||
|
output_bin_counts: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
prompt_bin_counts.zero_()
|
||||||
|
output_bin_counts.zero_()
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
|
||||||
|
_bincount_kernel[(num_blocks,)](
|
||||||
|
prefill_token_ids,
|
||||||
|
prefill_len,
|
||||||
|
prompt_len,
|
||||||
|
prompt_bin_counts,
|
||||||
|
output_bin_counts,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
79
vllm/v1/worker/gpu/sample/sampler.py
Normal file
79
vllm/v1/worker/gpu/sample/sampler.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config.model import LogprobsMode
|
||||||
|
from vllm.v1.outputs import SamplerOutput
|
||||||
|
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||||
|
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||||
|
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||||
|
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.worker.gpu.sample.penalties import apply_penalties
|
||||||
|
|
||||||
|
|
||||||
|
class Sampler:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
logprobs_mode: LogprobsMode = "raw_logprobs",
|
||||||
|
):
|
||||||
|
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
|
||||||
|
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
||||||
|
self.logprobs_mode = logprobs_mode
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> SamplerOutput:
|
||||||
|
if sampling_metadata.max_num_logprobs is not None:
|
||||||
|
if self.logprobs_mode == "processed_logprobs":
|
||||||
|
sampled, logits = self.sample(
|
||||||
|
logits, sampling_metadata, return_logits=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.logprobs_mode == "raw_logprobs"
|
||||||
|
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
||||||
|
|
||||||
|
logprobs_tensors = compute_topk_logprobs(
|
||||||
|
logits,
|
||||||
|
sampling_metadata.max_num_logprobs,
|
||||||
|
sampled,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
||||||
|
logprobs_tensors = None
|
||||||
|
|
||||||
|
# These are GPU tensors.
|
||||||
|
sampler_output = SamplerOutput(
|
||||||
|
# The sampled tokens are expanded to 2D tensor with shape
|
||||||
|
# [num_requests, 1], where each row represents one generated
|
||||||
|
# token per request.
|
||||||
|
sampled_token_ids=sampled.view(-1, 1),
|
||||||
|
logprobs_tensors=logprobs_tensors,
|
||||||
|
)
|
||||||
|
return sampler_output
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
return_logits: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
is_greedy = sampling_metadata.temperature == 0
|
||||||
|
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
||||||
|
logits = logits / temp.view(-1, 1)
|
||||||
|
logits = apply_top_k_top_p(
|
||||||
|
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
||||||
|
)
|
||||||
|
# Apply penalties in place.
|
||||||
|
apply_penalties(logits, sampling_metadata)
|
||||||
|
|
||||||
|
sampled = gumbel_sample(
|
||||||
|
logits,
|
||||||
|
sampling_metadata.temperature,
|
||||||
|
sampling_metadata.seeds,
|
||||||
|
sampling_metadata.pos,
|
||||||
|
apply_temperature=False,
|
||||||
|
)
|
||||||
|
return sampled, logits if return_logits else None
|
||||||
@@ -1,333 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm.config.model import LogprobsMode
|
|
||||||
from vllm.triton_utils import tl, triton
|
|
||||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
|
||||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
|
||||||
from vllm.v1.worker.gpu.penalties import apply_penalties
|
|
||||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
|
||||||
|
|
||||||
|
|
||||||
class Sampler:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
logprobs_mode: LogprobsMode = "raw_logprobs",
|
|
||||||
):
|
|
||||||
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
|
|
||||||
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
|
|
||||||
self.logprobs_mode = logprobs_mode
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> SamplerOutput:
|
|
||||||
if sampling_metadata.max_num_logprobs is not None:
|
|
||||||
if self.logprobs_mode == "processed_logprobs":
|
|
||||||
sampled, logits = self.sample(
|
|
||||||
logits, sampling_metadata, return_logits=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
assert self.logprobs_mode == "raw_logprobs"
|
|
||||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
|
||||||
|
|
||||||
logprobs_tensors = compute_topk_logprobs(
|
|
||||||
logits,
|
|
||||||
sampling_metadata.max_num_logprobs,
|
|
||||||
sampled,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
|
||||||
logprobs_tensors = None
|
|
||||||
|
|
||||||
# These are GPU tensors.
|
|
||||||
sampler_output = SamplerOutput(
|
|
||||||
# The sampled tokens are expanded to 2D tensor with shape
|
|
||||||
# [num_requests, 1], where each row represents one generated
|
|
||||||
# token per request.
|
|
||||||
sampled_token_ids=sampled.view(-1, 1),
|
|
||||||
logprobs_tensors=logprobs_tensors,
|
|
||||||
)
|
|
||||||
return sampler_output
|
|
||||||
|
|
||||||
def sample(
|
|
||||||
self,
|
|
||||||
logits: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
return_logits: bool = False,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
||||||
is_greedy = sampling_metadata.temperature == 0
|
|
||||||
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
|
||||||
logits = logits / temp.view(-1, 1)
|
|
||||||
logits = apply_top_k_top_p(
|
|
||||||
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
|
||||||
)
|
|
||||||
# Apply penalties in place.
|
|
||||||
apply_penalties(logits, sampling_metadata)
|
|
||||||
|
|
||||||
sampled = gumbel_sample(
|
|
||||||
logits,
|
|
||||||
sampling_metadata.temperature,
|
|
||||||
sampling_metadata.seeds,
|
|
||||||
sampling_metadata.pos,
|
|
||||||
apply_temperature=False,
|
|
||||||
)
|
|
||||||
return sampled, logits if return_logits else None
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _gumbel_sample_kernel(
|
|
||||||
local_argmax_ptr,
|
|
||||||
local_argmax_stride,
|
|
||||||
local_max_ptr,
|
|
||||||
local_max_stride,
|
|
||||||
logits_ptr,
|
|
||||||
logits_stride,
|
|
||||||
seeds_ptr,
|
|
||||||
pos_ptr,
|
|
||||||
temp_ptr,
|
|
||||||
vocab_size,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
APPLY_TEMPERATURE: tl.constexpr,
|
|
||||||
):
|
|
||||||
req_idx = tl.program_id(0)
|
|
||||||
block_idx = tl.program_id(1)
|
|
||||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = block < vocab_size
|
|
||||||
logits = tl.load(
|
|
||||||
logits_ptr + req_idx * logits_stride + block,
|
|
||||||
mask=mask,
|
|
||||||
other=float("-inf"),
|
|
||||||
)
|
|
||||||
logits = logits.to(tl.float32)
|
|
||||||
|
|
||||||
temp = tl.load(temp_ptr + req_idx).to(tl.float32)
|
|
||||||
if temp != 0.0:
|
|
||||||
# Calculate the seed for gumbel noise.
|
|
||||||
seed = tl.load(seeds_ptr + req_idx)
|
|
||||||
pos = tl.load(pos_ptr + req_idx)
|
|
||||||
gumbel_seed = tl.randint(seed, pos)
|
|
||||||
|
|
||||||
# Generate gumbel noise.
|
|
||||||
r = tl.rand(gumbel_seed, block).to(tl.float64)
|
|
||||||
gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
|
|
||||||
gumbel_noise = gumbel_noise.to(tl.float32)
|
|
||||||
|
|
||||||
# Apply temperature.
|
|
||||||
if APPLY_TEMPERATURE:
|
|
||||||
# NOTE(woosuk): Use div_rn to match the behavior of torch.
|
|
||||||
logits = tl.div_rn(logits, temp)
|
|
||||||
|
|
||||||
# Apply gumbel noise.
|
|
||||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
|
||||||
|
|
||||||
idx = tl.argmax(logits, axis=0)
|
|
||||||
token_id = block_idx * BLOCK_SIZE + idx
|
|
||||||
value = tl.max(logits, axis=0)
|
|
||||||
tl.store(local_argmax_ptr + req_idx * local_argmax_stride + block_idx, token_id)
|
|
||||||
tl.store(local_max_ptr + req_idx * local_max_stride + block_idx, value)
|
|
||||||
|
|
||||||
|
|
||||||
def gumbel_sample(
|
|
||||||
logits: torch.Tensor, # [num_reqs, vocab_size]
|
|
||||||
temperature: torch.Tensor, # [num_reqs]
|
|
||||||
seed: torch.Tensor, # [num_reqs]
|
|
||||||
pos: torch.Tensor, # [num_reqs]
|
|
||||||
apply_temperature: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
num_reqs, vocab_size = logits.shape
|
|
||||||
BLOCK_SIZE = 1024
|
|
||||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
|
||||||
local_argmax = torch.empty(
|
|
||||||
num_reqs,
|
|
||||||
num_blocks,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
local_max = torch.empty(
|
|
||||||
num_reqs,
|
|
||||||
num_blocks,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
_gumbel_sample_kernel[(num_reqs, num_blocks)](
|
|
||||||
local_argmax,
|
|
||||||
local_argmax.stride(0),
|
|
||||||
local_max,
|
|
||||||
local_max.stride(0),
|
|
||||||
logits,
|
|
||||||
logits.stride(0),
|
|
||||||
seed,
|
|
||||||
pos,
|
|
||||||
temperature,
|
|
||||||
vocab_size,
|
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
|
||||||
APPLY_TEMPERATURE=apply_temperature,
|
|
||||||
)
|
|
||||||
# NOTE(woosuk): Use int64 for later indexing.
|
|
||||||
max_block_idx = local_max.argmax(dim=-1, keepdim=True)
|
|
||||||
sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
|
|
||||||
return sampled
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _topk_log_softmax_kernel(
|
|
||||||
output_ptr,
|
|
||||||
logits_ptr,
|
|
||||||
logits_stride,
|
|
||||||
topk_ids_ptr,
|
|
||||||
topk,
|
|
||||||
vocab_size,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
PADDED_TOPK: tl.constexpr,
|
|
||||||
):
|
|
||||||
req_idx = tl.program_id(0)
|
|
||||||
row_ptr = logits_ptr + req_idx * logits_stride
|
|
||||||
|
|
||||||
max_val = float("-inf")
|
|
||||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
|
||||||
block = i + tl.arange(0, BLOCK_SIZE)
|
|
||||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
|
||||||
max_val = tl.max(tl.maximum(logits, max_val))
|
|
||||||
max_val = max_val.to(tl.float32) # type: ignore
|
|
||||||
|
|
||||||
se = 0.0
|
|
||||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
|
||||||
block = i + tl.arange(0, BLOCK_SIZE)
|
|
||||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
|
|
||||||
# NOTE(woosuk): Make sure that logits and all following operations use FP32.
|
|
||||||
logits = logits.to(tl.float32)
|
|
||||||
e = tl.exp(logits - max_val)
|
|
||||||
e = tl.where(block < vocab_size, e, 0.0)
|
|
||||||
se += tl.sum(e)
|
|
||||||
lse = tl.log(se)
|
|
||||||
|
|
||||||
k_offset = tl.arange(0, PADDED_TOPK)
|
|
||||||
k_mask = k_offset < topk
|
|
||||||
topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)
|
|
||||||
|
|
||||||
logits = tl.load(row_ptr + topk_ids, mask=k_mask)
|
|
||||||
logits = logits.to(tl.float32)
|
|
||||||
o = logits - max_val - lse
|
|
||||||
tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _ranks_kernel(
|
|
||||||
output_ptr,
|
|
||||||
logits_ptr,
|
|
||||||
logits_stride,
|
|
||||||
token_ids_ptr,
|
|
||||||
vocab_size,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
req_idx = tl.program_id(0)
|
|
||||||
row_ptr = logits_ptr + req_idx * logits_stride
|
|
||||||
|
|
||||||
token_id = tl.load(token_ids_ptr + req_idx)
|
|
||||||
x = tl.load(row_ptr + token_id)
|
|
||||||
|
|
||||||
n = 0
|
|
||||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
|
||||||
block = i + tl.arange(0, BLOCK_SIZE)
|
|
||||||
logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
|
|
||||||
n += tl.sum((logits > x).to(tl.int32))
|
|
||||||
tl.store(output_ptr + req_idx, n)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_token_logprobs(
|
|
||||||
logits: torch.Tensor,
|
|
||||||
token_ids: torch.Tensor,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
batch_size = logits.shape[0]
|
|
||||||
vocab_size = logits.shape[1]
|
|
||||||
token_ids = token_ids.to(torch.int64)
|
|
||||||
num_logprobs = token_ids.shape[1]
|
|
||||||
logprobs = torch.empty(
|
|
||||||
batch_size,
|
|
||||||
num_logprobs,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
_topk_log_softmax_kernel[(batch_size,)](
|
|
||||||
logprobs,
|
|
||||||
logits,
|
|
||||||
logits.stride(0),
|
|
||||||
token_ids,
|
|
||||||
num_logprobs,
|
|
||||||
vocab_size,
|
|
||||||
BLOCK_SIZE=1024, # type: ignore
|
|
||||||
PADDED_TOPK=triton.next_power_of_2(num_logprobs),
|
|
||||||
)
|
|
||||||
return logprobs
|
|
||||||
|
|
||||||
|
|
||||||
def compute_topk_logprobs(
|
|
||||||
logits: torch.Tensor,
|
|
||||||
num_logprobs: int,
|
|
||||||
sampled_token_ids: torch.Tensor,
|
|
||||||
) -> LogprobsTensors:
|
|
||||||
assert num_logprobs >= 0
|
|
||||||
batch_size, vocab_size = logits.shape
|
|
||||||
if num_logprobs == 0:
|
|
||||||
logprob_token_ids = sampled_token_ids.unsqueeze(-1)
|
|
||||||
else:
|
|
||||||
topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
|
|
||||||
logprob_token_ids = torch.cat(
|
|
||||||
(sampled_token_ids.unsqueeze(-1), topk_indices), dim=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
|
|
||||||
# logprobs tensor. Instead, we only compute and return the logprobs of
|
|
||||||
# the topk + 1 tokens.
|
|
||||||
logprobs = compute_token_logprobs(logits, logprob_token_ids)
|
|
||||||
token_ranks = torch.empty(
|
|
||||||
batch_size,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=logits.device,
|
|
||||||
)
|
|
||||||
_ranks_kernel[(batch_size,)](
|
|
||||||
token_ranks,
|
|
||||||
logits,
|
|
||||||
logits.stride(0),
|
|
||||||
sampled_token_ids,
|
|
||||||
vocab_size,
|
|
||||||
BLOCK_SIZE=8192, # type: ignore
|
|
||||||
)
|
|
||||||
return LogprobsTensors(
|
|
||||||
logprob_token_ids=logprob_token_ids,
|
|
||||||
logprobs=logprobs,
|
|
||||||
selected_token_ranks=token_ranks,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_prompt_logprobs(
|
|
||||||
prompt_token_ids: torch.Tensor,
|
|
||||||
prompt_hidden_states: torch.Tensor,
|
|
||||||
logits_fn: Callable[[torch.Tensor], torch.Tensor],
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
# Since materializing the full prompt logits can take too much memory,
|
|
||||||
# we compute it in chunks.
|
|
||||||
CHUNK_SIZE = 1024
|
|
||||||
logprobs = []
|
|
||||||
ranks = []
|
|
||||||
prompt_token_ids = prompt_token_ids.to(torch.int64)
|
|
||||||
for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
|
|
||||||
end_idx = start_idx + CHUNK_SIZE
|
|
||||||
# NOTE(woosuk): logits_fn can be slow because it involves all-gather.
|
|
||||||
prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
|
|
||||||
prompt_logprobs = compute_topk_logprobs(
|
|
||||||
prompt_logits,
|
|
||||||
0, # num_logprobs
|
|
||||||
prompt_token_ids[start_idx:end_idx],
|
|
||||||
)
|
|
||||||
logprobs.append(prompt_logprobs.logprobs)
|
|
||||||
ranks.append(prompt_logprobs.selected_token_ranks)
|
|
||||||
|
|
||||||
logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
|
|
||||||
ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
|
|
||||||
return logprobs, ranks
|
|
||||||
@@ -18,9 +18,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
|
|||||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||||
from vllm.v1.worker.gpu.sampler import gumbel_sample
|
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||||
|
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
|
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
|
||||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -7,86 +7,18 @@ import torch
|
|||||||
|
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.triton_utils import tl, triton
|
|
||||||
from vllm.utils.platform_utils import is_uva_available
|
from vllm.utils.platform_utils import is_uva_available
|
||||||
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
|
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
|
||||||
from vllm.v1.outputs import LogprobsTensors
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
from vllm.v1.utils import CpuGpuBuffer
|
from vllm.v1.utils import CpuGpuBuffer
|
||||||
|
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||||
|
from vllm.v1.worker.gpu.sample.penalties import bincount
|
||||||
|
|
||||||
_NP_INT64_MIN = np.iinfo(np.int64).min
|
_NP_INT64_MIN = np.iinfo(np.int64).min
|
||||||
_NP_INT64_MAX = np.iinfo(np.int64).max
|
_NP_INT64_MAX = np.iinfo(np.int64).max
|
||||||
NO_LORA_ID = 0
|
NO_LORA_ID = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SamplingMetadata:
|
|
||||||
temperature: torch.Tensor
|
|
||||||
|
|
||||||
top_p: torch.Tensor | None
|
|
||||||
top_k: torch.Tensor | None
|
|
||||||
|
|
||||||
repetition_penalty: torch.Tensor
|
|
||||||
frequency_penalty: torch.Tensor
|
|
||||||
presence_penalty: torch.Tensor
|
|
||||||
|
|
||||||
seeds: torch.Tensor
|
|
||||||
pos: torch.Tensor
|
|
||||||
|
|
||||||
# None means no logprobs, 0 means sampled token logprobs only
|
|
||||||
max_num_logprobs: int | None
|
|
||||||
|
|
||||||
# For penalties
|
|
||||||
idx_mapping: torch.Tensor
|
|
||||||
prompt_bin_counts: torch.Tensor
|
|
||||||
output_bin_counts: torch.Tensor
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def make_dummy(
|
|
||||||
cls,
|
|
||||||
num_reqs: int,
|
|
||||||
device: torch.device,
|
|
||||||
) -> "SamplingMetadata":
|
|
||||||
assert num_reqs > 0
|
|
||||||
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
|
||||||
temperature[0] = 0.5
|
|
||||||
# TODO(woosuk): Use top-p and top-k for dummy sampler.
|
|
||||||
# Currently, they are disabled because of memory usage.
|
|
||||||
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
|
|
||||||
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
|
|
||||||
top_p = None
|
|
||||||
top_k = None
|
|
||||||
# NOTE(woosuk): We must set penalties to their default values to make sure
|
|
||||||
# the penalties kernel does not touch the placeholder bin_counts tensors.
|
|
||||||
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
|
|
||||||
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
|
||||||
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
|
|
||||||
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
|
||||||
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
|
|
||||||
max_num_logprobs = 20
|
|
||||||
|
|
||||||
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
|
|
||||||
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
|
|
||||||
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
|
|
||||||
# specialization and re-compilation at runtime.
|
|
||||||
prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
|
||||||
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
repetition_penalty=repetition_penalty,
|
|
||||||
frequency_penalty=frequency_penalty,
|
|
||||||
presence_penalty=presence_penalty,
|
|
||||||
seeds=seeds,
|
|
||||||
pos=pos,
|
|
||||||
max_num_logprobs=max_num_logprobs,
|
|
||||||
idx_mapping=idx_mapping,
|
|
||||||
prompt_bin_counts=prompt_bin_counts,
|
|
||||||
output_bin_counts=output_bin_counts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RequestState:
|
class RequestState:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -311,17 +243,6 @@ class RequestState:
|
|||||||
output_bin_counts=self.output_bin_counts,
|
output_bin_counts=self.output_bin_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
def expand_sampling_metadata(
|
|
||||||
self,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
cu_num_logits: torch.Tensor,
|
|
||||||
) -> SamplingMetadata:
|
|
||||||
# For draft tokens, we need to expand the sampling param tensors as
|
|
||||||
# each request samples multiple tokens in each step.
|
|
||||||
return expand_sampling_metadata(
|
|
||||||
sampling_metadata, cu_num_logits, self.num_speculative_steps
|
|
||||||
)
|
|
||||||
|
|
||||||
def make_lora_inputs(
|
def make_lora_inputs(
|
||||||
self,
|
self,
|
||||||
req_ids: list[str],
|
req_ids: list[str],
|
||||||
@@ -376,158 +297,9 @@ class UvaBuffer:
|
|||||||
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
|
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
|
||||||
|
|
||||||
|
|
||||||
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
|
|
||||||
@triton.jit
|
|
||||||
def _expand_sampling_metadata_kernel(
|
|
||||||
temp_ptr,
|
|
||||||
expanded_temp_ptr,
|
|
||||||
top_p_ptr,
|
|
||||||
expanded_top_p_ptr,
|
|
||||||
top_k_ptr,
|
|
||||||
expanded_top_k_ptr,
|
|
||||||
rep_penalty_ptr,
|
|
||||||
expanded_rep_penalty_ptr,
|
|
||||||
freq_penalty_ptr,
|
|
||||||
expanded_freq_penalty_ptr,
|
|
||||||
pres_penalty_ptr,
|
|
||||||
expanded_pres_penalty_ptr,
|
|
||||||
seeds_ptr,
|
|
||||||
expanded_seeds_ptr,
|
|
||||||
cu_num_logits_ptr,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
req_idx = tl.program_id(0)
|
|
||||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
|
||||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
|
||||||
num_tokens = end_idx - start_idx
|
|
||||||
|
|
||||||
block = tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = block < num_tokens
|
|
||||||
|
|
||||||
temp = tl.load(temp_ptr + req_idx)
|
|
||||||
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
|
|
||||||
|
|
||||||
if top_p_ptr is not None:
|
|
||||||
top_p = tl.load(top_p_ptr + req_idx)
|
|
||||||
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
|
|
||||||
|
|
||||||
if top_k_ptr is not None:
|
|
||||||
top_k = tl.load(top_k_ptr + req_idx)
|
|
||||||
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
|
|
||||||
|
|
||||||
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
|
|
||||||
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
|
|
||||||
|
|
||||||
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
|
|
||||||
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
|
|
||||||
|
|
||||||
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
|
|
||||||
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
|
|
||||||
|
|
||||||
seed = tl.load(seeds_ptr + req_idx)
|
|
||||||
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def expand_sampling_metadata(
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
cu_num_logits: torch.Tensor,
|
|
||||||
num_speculative_steps: int,
|
|
||||||
) -> SamplingMetadata:
|
|
||||||
total_num_logits = sampling_metadata.pos.shape[0]
|
|
||||||
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
|
|
||||||
expanded_temp = create_empty(sampling_metadata.temperature)
|
|
||||||
expanded_top_p = create_empty(sampling_metadata.top_p)
|
|
||||||
expanded_top_k = create_empty(sampling_metadata.top_k)
|
|
||||||
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
|
|
||||||
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
|
|
||||||
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
|
|
||||||
expanded_seeds = create_empty(sampling_metadata.seeds)
|
|
||||||
|
|
||||||
num_reqs = cu_num_logits.shape[0] - 1
|
|
||||||
_expand_sampling_metadata_kernel[(num_reqs,)](
|
|
||||||
sampling_metadata.temperature,
|
|
||||||
expanded_temp,
|
|
||||||
sampling_metadata.top_p,
|
|
||||||
expanded_top_p,
|
|
||||||
sampling_metadata.top_k,
|
|
||||||
expanded_top_k,
|
|
||||||
sampling_metadata.repetition_penalty,
|
|
||||||
expanded_repetition_penalty,
|
|
||||||
sampling_metadata.frequency_penalty,
|
|
||||||
expanded_frequency_penalty,
|
|
||||||
sampling_metadata.presence_penalty,
|
|
||||||
expanded_presence_penalty,
|
|
||||||
sampling_metadata.seeds,
|
|
||||||
expanded_seeds,
|
|
||||||
cu_num_logits,
|
|
||||||
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
|
|
||||||
)
|
|
||||||
return SamplingMetadata(
|
|
||||||
temperature=expanded_temp,
|
|
||||||
top_p=expanded_top_p,
|
|
||||||
top_k=expanded_top_k,
|
|
||||||
seeds=expanded_seeds,
|
|
||||||
repetition_penalty=expanded_repetition_penalty,
|
|
||||||
frequency_penalty=expanded_frequency_penalty,
|
|
||||||
presence_penalty=expanded_presence_penalty,
|
|
||||||
pos=sampling_metadata.pos,
|
|
||||||
max_num_logprobs=sampling_metadata.max_num_logprobs,
|
|
||||||
# TODO(woosuk): Support penalties with spec decoding.
|
|
||||||
idx_mapping=sampling_metadata.idx_mapping,
|
|
||||||
prompt_bin_counts=sampling_metadata.prompt_bin_counts,
|
|
||||||
output_bin_counts=sampling_metadata.output_bin_counts,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def use_penalty(sampling_params: SamplingParams) -> bool:
|
def use_penalty(sampling_params: SamplingParams) -> bool:
|
||||||
return (
|
return (
|
||||||
sampling_params.repetition_penalty != 1.0
|
sampling_params.repetition_penalty != 1.0
|
||||||
or sampling_params.frequency_penalty != 0.0
|
or sampling_params.frequency_penalty != 0.0
|
||||||
or sampling_params.presence_penalty != 0.0
|
or sampling_params.presence_penalty != 0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
|
|
||||||
def _bincount_kernel(
|
|
||||||
prefill_token_ids_ptr,
|
|
||||||
prefill_len,
|
|
||||||
prompt_len,
|
|
||||||
prompt_bin_counts_ptr,
|
|
||||||
output_bin_counts_ptr,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
):
|
|
||||||
block_idx = tl.program_id(0)
|
|
||||||
if block_idx * BLOCK_SIZE >= prefill_len:
|
|
||||||
return
|
|
||||||
|
|
||||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
||||||
if block_idx * BLOCK_SIZE < prompt_len:
|
|
||||||
mask = block < prompt_len
|
|
||||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
|
||||||
tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
|
||||||
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
|
|
||||||
mask = block < prefill_len
|
|
||||||
mask &= block >= prompt_len
|
|
||||||
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
|
|
||||||
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def bincount(
|
|
||||||
prefill_token_ids: torch.Tensor,
|
|
||||||
prefill_len: int,
|
|
||||||
prompt_len: int,
|
|
||||||
prompt_bin_counts: torch.Tensor,
|
|
||||||
output_bin_counts: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
prompt_bin_counts.zero_()
|
|
||||||
output_bin_counts.zero_()
|
|
||||||
BLOCK_SIZE = 1024
|
|
||||||
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
|
|
||||||
_bincount_kernel[(num_blocks,)](
|
|
||||||
prefill_token_ids,
|
|
||||||
prefill_len,
|
|
||||||
prompt_len,
|
|
||||||
prompt_bin_counts,
|
|
||||||
output_bin_counts,
|
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user