202 lines
7.2 KiB
Plaintext
202 lines
7.2 KiB
Plaintext
/**
 * Production fused sampler kernel for DSV4 inference.
 *
 * Fused: repetition penalty → temperature → top-k → top-p (nucleus) → sample.
 * Single kernel launch, zero CPU syncs, CUDA-graph-compatible.
 *
 * Architecture:
 * - 1 CUDA block per batch item
 * - 256 threads per block
 * - Each thread scans its slice of the vocab, applies penalty + temperature,
 *   and tracks its top-LK candidates in a small sorted per-thread array
 * - Thread 0 merges all 256 per-thread lists into a global top-k
 * - Thread 0 computes softmax over top-k, applies top-p, and samples
 *
 * SMEM: 256 * LK * 8 bytes (scores + indices)
 *       = 256 * 24 * 8 = 48KB for LK=24
 * Each thread tracks top-24; the merge considers 256*24=6144 candidates,
 * yielding an effective top-k of up to 128 (more than enough for any
 * practical use case). The selection is exact for top_k <= LK; for larger
 * top_k it is approximate whenever a single thread's slice holds more than
 * LK of the true top-k entries.
 *
 * Repetition penalty: passed as two (batch, max_pen) arrays, pen_ids holding
 * the token_id and pen_vals holding the matching penalty_value
 * (multiplicative: >1.0 penalizes, <1.0 boosts).
 * The penalty is applied as: if logit > 0, logit /= penalty; else logit *= penalty.
 * This matches the HuggingFace generate() convention.
 */
|
|
|
|
#include <cuda.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include <cfloat>
#include <cstdint>
#include <limits>
#include <optional>
|
|
|
|
// Threads per block. The kernel strides over the vocabulary by BDIM and sizes
// its shared-memory layout with it, so launches must use exactly BDIM threads.
static constexpr int BDIM = 256;

// Length of each thread's local top-k list. Dynamic shared memory per block is
// BDIM * LK * (sizeof(float) + sizeof(int)) bytes: 256 * 24 * 8 = 48KB, which
// fits the default per-block limit.
static constexpr int LK = 24;

// Compile-time sanity checks on the launch geometry the kernel assumes.
static_assert(BDIM > 0 && BDIM % 32 == 0, "BDIM must be a positive multiple of the warp size");
static_assert(LK > 0, "LK must be positive");
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Insert into sorted descending array (register-resident, k small)
|
|
// ---------------------------------------------------------------------------
|
|
/**
 * Insert candidate (s, i) into a descending-sorted top-k list.
 *
 * @param sc   Scores; the first n entries are valid and sorted descending.
 *             Must have room for at least k entries.
 * @param idx  Indices parallel to sc.
 * @param k    Capacity of the list. A value <= 0 makes the call a no-op.
 * @param n    Number of valid entries; incremented while the list is not full.
 * @param s    Candidate score.
 * @param i    Candidate index.
 *
 * While the list has free slots the candidate is always inserted. Once it is
 * full, the candidate replaces the current minimum only if it is strictly
 * greater, so among equal scores the earlier insertion is kept.
 *
 * NaN scores are rejected: every comparison against NaN is false, so a NaN
 * entry would break the descending order and could block later insertions.
 *
 * Also callable from host code so the helper can be unit-tested on the CPU.
 */
__host__ __device__ void sorted_insert(float* sc, int* idx, int k, int& n, float s, int i) {
    // Reject degenerate capacity and NaN (s != s is true only for NaN).
    if (k <= 0 || s != s) return;

    int pos;
    if (n < k) {
        pos = n;        // free slot at the tail
        ++n;
    } else if (s > sc[k - 1]) {
        pos = k - 1;    // evict the current minimum
    } else {
        return;         // not better than anything already kept
    }

    // Shift strictly smaller entries one slot towards the tail.
    while (pos > 0 && s > sc[pos - 1]) {
        sc[pos] = sc[pos - 1];
        idx[pos] = idx[pos - 1];
        --pos;
    }
    sc[pos] = s;
    idx[pos] = i;
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Kernel
|
|
// ---------------------------------------------------------------------------
|
|
/**
 * Fused sampling kernel: repetition penalty -> temperature -> top-k -> top-p -> sample.
 *
 * Launch layout: one block per batch row (blockIdx.x = row), exactly BDIM
 * threads per block, and BDIM * LK * (sizeof(float) + sizeof(int)) bytes of
 * dynamic shared memory.
 *
 * @param logits    (B, V) float logits; consecutive rows are vs elements apart.
 * @param pen_ids   (B, max_pen) token ids to penalise, or nullptr.
 * @param pen_vals  (B, max_pen) penalty for the matching pen_ids entry, or nullptr.
 *                  For each vocab index the first matching entry in the row is
 *                  applied: positive logits are divided by it, others multiplied.
 * @param temp      Temperature. Values <= 0 (or NaN) select greedy decoding:
 *                  the highest penalised logit is returned without sampling.
 * @param top_k     Number of candidates kept; <= 0 or > 128 means 128.
 * @param top_p     Nucleus threshold; values >= 1 disable the nucleus cut.
 * @param min_keep  Minimum number of candidates kept by the nucleus cut.
 * @param seed, offset  cuRAND seed and offset; the row index is the subsequence.
 * @param out_ids   (B,) sampled token id per row. Receives 0 when the row has
 *                  no usable candidate (e.g. V == 0 or every score is filtered).
 *
 * Each thread keeps only its LK best candidates, so the global top-k is exact
 * for top_k <= LK and approximate beyond that if a single thread's slice holds
 * more than LK of the true top-k entries.
 */
__global__ void fused_sampler_kernel(
    const float* __restrict__ logits,     // (B, V) stride=vs
    const int64_t* __restrict__ pen_ids,  // (B, max_pen) or nullptr
    const float* __restrict__ pen_vals,   // (B, max_pen) or nullptr
    int B, int V, int vs, int max_pen,
    float temp, int top_k, float top_p, int min_keep,
    uint64_t seed, uint64_t offset,
    int64_t* __restrict__ out_ids         // (B,)
) {
    // Upper bound on the global top-k; sizes thread 0's stack arrays.
    constexpr int MAX_K = 128;

    const int b = blockIdx.x;
    if (b >= B) return;  // uniform across the block, so the barrier below stays safe
    const int tid = threadIdx.x;

    // 64-bit row offsets: b * vs and b * max_pen can exceed the int range.
    const float* row = logits + static_cast<int64_t>(b) * vs;
    const bool use_pen = (pen_ids != nullptr) && (pen_vals != nullptr) && (max_pen > 0);
    const int64_t* pen_id_row = use_pen ? pen_ids + static_cast<int64_t>(b) * max_pen : nullptr;
    const float* pen_val_row = use_pen ? pen_vals + static_cast<int64_t>(b) * max_pen : nullptr;

    // A non-positive (or NaN) temperature would turn the division below into
    // inf/NaN scores, so it is treated as a request for greedy decoding.
    const bool greedy = !(temp > 0.0f);

    // ---------- Phase 1: per-thread top-LK ----------
    float lsc[LK];
    int lid[LK];
    int ln = 0;

    for (int v = tid; v < V; v += BDIM) {
        float val = row[v];

        // Repetition penalty: first matching entry in this row's list wins.
        if (use_pen) {
            for (int p = 0; p < max_pen; p++) {
                if (pen_id_row[p] == v) {
                    const float pen = pen_val_row[p];
                    val = (val > 0.0f) ? val / pen : val * pen;
                    break;
                }
            }
        }

        if (!greedy) val /= temp;
        sorted_insert(lsc, lid, LK, ln, val, v);
    }

    // ---------- Phase 2: write to SMEM, thread 0 merges ----------
    extern __shared__ char smem[];
    float* s_sc = reinterpret_cast<float*>(smem);
    int* s_idx = reinterpret_cast<int*>(smem + BDIM * LK * sizeof(float));

    // Publish this thread's list; unused slots carry the -FLT_MAX sentinel.
    const int base = tid * LK;
    for (int i = 0; i < LK; i++) {
        const bool valid = i < ln;
        s_sc[base + i] = valid ? lsc[i] : -FLT_MAX;
        s_idx[base + i] = valid ? lid[i] : 0;
    }
    __syncthreads();

    // Everything after the barrier is serial work for thread 0.
    if (tid != 0) return;

    // Merge: find global top-k from BDIM * LK = 6144 candidates.
    int eff_k = (top_k <= 0) ? MAX_K : min(top_k, MAX_K);
    if (greedy) eff_k = 1;  // only the maximum is needed

    float gsc[MAX_K];
    int gid[MAX_K];
    int gn = 0;
    for (int c = 0; c < BDIM * LK; c++) {
        const float s = s_sc[c];
        // Skips sentinel slots; -inf scores are dropped by the same test.
        if (s <= -FLT_MAX) continue;
        sorted_insert(gsc, gid, eff_k, gn, s, s_idx[c]);
    }

    if (gn == 0) { out_ids[b] = 0; return; }
    if (greedy) { out_ids[b] = gid[0]; return; }

    // ---------- Phase 3: softmax + top-p + sample ----------
    const float mx = gsc[0];  // sorted desc, first is max
    float probs[MAX_K];
    float total = 0.0f;
    for (int i = 0; i < gn; i++) {
        probs[i] = expf(gsc[i] - mx);
        total += probs[i];
    }

    // Top-p: shortest prefix whose cumulative probability reaches top_p,
    // but never fewer than min_keep candidates.
    int nk = gn;
    if (top_p < 1.0f) {
        float cs = 0.0f;
        for (int i = 0; i < gn; i++) {
            cs += probs[i];
            if (cs / total >= top_p) { nk = max(i + 1, min_keep); break; }
        }
    }
    // min_keep may exceed the number of candidates actually collected; clamp
    // so the loops below never read past the gn valid entries.
    nk = max(1, min(nk, gn));

    // Renormalize over the kept prefix.
    float kt = 0.0f;
    for (int i = 0; i < nk; i++) kt += probs[i];

    // Sample by inverting the cumulative distribution of the kept prefix.
    curandState rng;
    curand_init(seed, b, offset, &rng);
    const float r = curand_uniform(&rng) * kt;
    float acc = 0.0f;
    int sel = nk - 1;  // fallback if rounding keeps acc below r
    for (int i = 0; i < nk; i++) {
        acc += probs[i];
        if (acc >= r) { sel = i; break; }
    }
    out_ids[b] = gid[sel];
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Binding
|
|
// ---------------------------------------------------------------------------
|
|
/**
 * Host entry point: validates inputs and launches fused_sampler_kernel on the
 * current CUDA stream of the logits' device. No host-side synchronisation is
 * performed.
 *
 * @param logits       Contiguous 2-D float32 CUDA tensor of shape (B, V).
 * @param pen_ids      Optional contiguous int64 tensor (>= B rows, max_pen
 *                     columns) of token ids to penalise. Ignored when absent
 *                     or empty.
 * @param pen_vals     Contiguous float32 tensor with the same shape as
 *                     pen_ids; required whenever pen_ids is used.
 * @param temperature  Divisor applied to the logits by the kernel.
 * @param top_k        Candidate count; the kernel treats <= 0 or > 128 as 128.
 * @param top_p        Nucleus threshold.
 * @param min_keep     Minimum candidates kept by the nucleus cut.
 * @param seed, offset Forwarded to the kernel's cuRAND initialisation.
 * @return int64 tensor of shape (B,) on the logits' device with one token id per row.
 *
 * NOTE(review): seed and offset are passed to the kernel by value, so a
 * captured CUDA graph would presumably replay the same values — confirm this
 * matches the caller's expectations.
 */
torch::Tensor sample_cuda(
    torch::Tensor logits,
    std::optional<torch::Tensor> pen_ids,
    std::optional<torch::Tensor> pen_vals,
    double temperature,
    int64_t top_k,
    double top_p,
    int64_t min_keep,
    int64_t seed,
    int64_t offset
) {
    constexpr int64_t kIntMax = std::numeric_limits<int>::max();

    TORCH_CHECK(logits.is_cuda(), "sample: logits must be a CUDA tensor");
    TORCH_CHECK(logits.is_contiguous() && logits.dim() == 2 && logits.scalar_type() == torch::kFloat32,
                "sample: logits must be a contiguous 2-D float32 tensor");
    TORCH_CHECK(logits.size(0) <= kIntMax && logits.size(1) <= kIntMax,
                "sample: logits dimensions exceed the supported int range");
    const int B = static_cast<int>(logits.size(0));
    const int V = static_cast<int>(logits.size(1));

    // Make the logits' device current so the stream lookup, the function
    // attributes and the launch all target the right GPU.
    const c10::cuda::CUDAGuard device_guard(logits.device());

    auto out = torch::empty({B}, logits.options().dtype(torch::kInt64));
    if (B == 0) return out;  // an empty grid is an invalid launch configuration

    // Optional repetition-penalty tables.
    int mp = 0;
    const int64_t* pi = nullptr;
    const float* pv = nullptr;
    if (pen_ids.has_value() && pen_ids->numel() > 0) {
        TORCH_CHECK(pen_vals.has_value(), "sample: pen_vals is required when pen_ids is given");
        const torch::Tensor& ids = *pen_ids;
        const torch::Tensor& vals = *pen_vals;
        TORCH_CHECK(ids.dim() == 2 && ids.size(0) >= B,
                    "sample: pen_ids must be 2-D with at least one row per batch item");
        TORCH_CHECK(vals.sizes() == ids.sizes(), "sample: pen_vals must have the same shape as pen_ids");
        TORCH_CHECK(ids.scalar_type() == torch::kInt64 && vals.scalar_type() == torch::kFloat32,
                    "sample: pen_ids must be int64 and pen_vals must be float32");
        TORCH_CHECK(ids.is_contiguous() && vals.is_contiguous(),
                    "sample: pen_ids and pen_vals must be contiguous");
        TORCH_CHECK(ids.device() == logits.device() && vals.device() == logits.device(),
                    "sample: pen_ids and pen_vals must be on the same device as logits");
        TORCH_CHECK(ids.size(1) <= kIntMax, "sample: pen_ids has too many columns");
        mp = static_cast<int>(ids.size(1));
        pi = ids.data_ptr<int64_t>();
        pv = vals.data_ptr<float>();
    }

    // Clamp before narrowing so out-of-range values cannot wrap around.
    // Negative values map to 0, which the kernel treats like any value <= 0.
    auto clamp_to_int = [kIntMax](int64_t x) {
        return static_cast<int>(x < 0 ? 0 : (x > kIntMax ? kIntMax : x));
    };

    const int smem = static_cast<int>(BDIM * LK * (sizeof(float) + sizeof(int)));

    // Request enough dynamic shared memory for the candidate lists.
    cudaError_t err = cudaFuncSetAttribute(
        fused_sampler_kernel,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem);
    TORCH_CHECK(err == cudaSuccess,
                "sample: cudaFuncSetAttribute(MaxDynamicSharedMemorySize) failed: ",
                cudaGetErrorString(err));

    // Carveout: prefer more shared memory over L1. This is only a hint, so a
    // failure is cleared instead of being reported by the launch check below.
    err = cudaFuncSetAttribute(
        fused_sampler_kernel,
        cudaFuncAttributePreferredSharedMemoryCarveout,
        cudaSharedmemCarveoutMaxShared);
    if (err != cudaSuccess) (void)cudaGetLastError();

    cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();
    fused_sampler_kernel<<<B, BDIM, smem, stream>>>(
        logits.data_ptr<float>(), pi, pv,
        B, V, /*vs=*/V,  // contiguous rows are exactly V elements apart
        mp,
        static_cast<float>(temperature), clamp_to_int(top_k),
        static_cast<float>(top_p), clamp_to_int(min_keep),
        static_cast<uint64_t>(seed), static_cast<uint64_t>(offset),
        out.data_ptr<int64_t>());

    // Launch-configuration errors only surface through the last-error state.
    err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "sample: fused_sampler_kernel launch failed: ",
                cudaGetErrorString(err));
    return out;
}
|
|
|
|
// Python bindings: registers sample_cuda on the extension module as `sample`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "sample",
        &sample_cuda,
        "Fused top-k/top-p sampler");
}
|