E1: Wire LayerCacheHandle gather methods + CUDA gather kernels
- gather_compressed_kv: CSA top-k gather via existing gather_kv.cu; - gather_all_compressed_kv: HCA dense gather via new gather_all_compressed_kernel; - gather_swa_kv: SWA ring buffer gather via new gather_swa_kernel; - Added gather_swa.cu with both SWA + all-compressed gather kernels; - Added gather.py Python wrapper (torch.utils.cpp_extension JIT); - Updated handle.py: added schema field, num_query_heads/head_dim properties; - Updated manager.py: passes schema + num_query_heads to handle. All gather kernels: FP8→BF16 dequant + BF16 RoPE concat in a single launch. Output: dense BF16 tensors ready for FMHA consumption.
This commit is contained in:
146
dsv4/cache/handle.py
vendored
146
dsv4/cache/handle.py
vendored
@@ -13,6 +13,7 @@ import torch
|
||||
if TYPE_CHECKING:
|
||||
from dsv4.cache.paged_cache import PagedKVPool
|
||||
from dsv4.cache.state_cache import StateCachePool
|
||||
from dsv4.cache.schema import LayerCacheSchema
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -25,6 +26,7 @@ class LayerCacheHandle:
|
||||
# Pool references (shared across handles — never mutated).
|
||||
paged: Optional["PagedKVPool"]
|
||||
state: "StateCachePool"
|
||||
schema: "LayerCacheSchema"
|
||||
|
||||
# Per-call indices.
|
||||
request_slots: torch.Tensor # [batch] int32 — state cache slot per request
|
||||
@@ -37,9 +39,30 @@ class LayerCacheHandle:
|
||||
# Number of valid blocks per request (excludes padding).
|
||||
block_lens: Optional[torch.Tensor]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Properties called by AttentionSubBlock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
def num_query_heads(self) -> int:
    """Return the query-head count recorded on this handle.

    The value is whatever was last assigned through the setter; it is kept
    in the private ``_num_query_heads`` attribute rather than being read
    from the schema.
    """
    return getattr(self, "_num_query_heads")


@num_query_heads.setter
def num_query_heads(self, value: int):
    """Record the query-head count on this handle."""
    setattr(self, "_num_query_heads", value)
|
||||
|
||||
@property
def head_dim(self) -> int:
    """Per-entry head dimension, read from ``schema.entry_head_dim``."""
    layer_schema = self.schema
    return layer_schema.entry_head_dim
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Methods called by AttentionSubBlock
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def write_swa(
|
||||
self,
|
||||
raw_kv: torch.Tensor, # (T, head_dim) BF16
|
||||
@@ -59,7 +82,7 @@ class LayerCacheHandle:
|
||||
swa_inv=self.state.swa_inv,
|
||||
swa_pos=self.state.swa_pos,
|
||||
swa_head=self.state.swa_head,
|
||||
rope_dim=self.state.schema.rope_dim,
|
||||
rope_dim=self.schema.rope_dim,
|
||||
)
|
||||
|
||||
def flush_compression(
|
||||
@@ -78,6 +101,122 @@ class LayerCacheHandle:
|
||||
"""
|
||||
raise NotImplementedError("see kernels/cache/flush_compression.py")
|
||||
|
||||
def gather_compressed_kv(
    self,
    selected_indices: torch.Tensor,  # (T, top_k) int64 — from indexer
) -> tuple[torch.Tensor, torch.Tensor]:
    """CSA: gather the top-k selected compressed KV entries as dense BF16.

    Args:
        selected_indices: (T, top_k) integer indices from the indexer. They
            are converted to int32 before being passed to the kernel wrapper.

    Returns:
        ``(k_compressed, v_compressed)``. At decode (T == 1) each has shape
        (1, top_k, head_dim); the leading dim stands for the single KV head.
        For T > 1 the wrapper's (T, top_k, head_dim) result only gains a
        leading singleton dim, i.e. (1, T, top_k, head_dim).
        ``v_compressed`` is an independent copy of ``k_compressed``.

    Raises:
        AssertionError: if the handle has no paged pool or no block table.
    """
    assert self.paged is not None, "CSA gather requires paged pool"
    # The manager can build a handle with block_table=None; fail with a clear
    # message instead of an AttributeError on `.dim()` below.
    assert self.block_table is not None, "CSA gather requires a block table"
    from dsv4.kernels.cache.gather import gather_compressed_kv as _gather_topk

    hd = self.head_dim
    rd = self.schema.rope_dim
    epb = self.schema.entries_per_block

    # selected_indices is int64; the gather wrapper documents int32 indices.
    indices_i32 = selected_indices.to(torch.int32)

    # The wrapper documents a [batch, max_logical_blocks] table, so a 1-D
    # single-request table is promoted to batch size 1.
    bt = self.block_table
    if bt.dim() == 1:
        bt = bt.unsqueeze(0)

    k_out = _gather_topk(
        entries_fp8=self.paged.entries_fp8,
        entries_rope=self.paged.entries_rope,
        inv_scale=self.paged.inv_scale,
        topk_indices=indices_i32,
        block_table=bt,
        entries_per_block=epb,
        head_dim=hd,
        rope_dim=rd,
    )

    # k_out is (T, top_k, hd). With T == 1 this leaves the shape at
    # (1, top_k, hd); with T > 1 squeeze(0) is a no-op and a leading dim is
    # added.
    k_compressed = k_out.squeeze(0).unsqueeze(0)

    # The same gathered entries are returned for V.
    # TODO: verify if K and V are stored separately or shared.
    v_compressed = k_compressed.clone()
    return k_compressed, v_compressed
|
||||
|
||||
def gather_all_compressed_kv(self) -> tuple[torch.Tensor, torch.Tensor]:
    """HCA: gather ALL compressed KV entries into dense BF16 tensors.

    No indexer is involved: every block listed for the request is gathered.

    Returns:
        ``(k_compressed, v_compressed)``. For a single request each has
        shape (1, n_comp, head_dim), where n_comp is the second dim of the
        wrapper's (batch, total_entries, head_dim) result. For batch > 1 the
        wrapper result only gains a leading singleton dim.
        ``v_compressed`` is an independent copy of ``k_compressed``.

    Raises:
        AssertionError: if the handle has no paged pool or no block table.
    """
    assert self.paged is not None, "HCA gather requires paged pool"
    # The manager can build a handle with block_table=None; fail with a clear
    # message instead of an AttributeError on `.dim()` below.
    assert self.block_table is not None, "HCA gather requires a block table"
    from dsv4.kernels.cache.gather import gather_all_compressed_kv as _gather_dense

    hd = self.head_dim
    rd = self.schema.rope_dim
    epb = self.schema.entries_per_block

    bt = self.block_table
    bl = self.block_lens
    if bt.dim() == 1:
        bt = bt.unsqueeze(0)
        if bl is not None:
            # Accept both a 0-D and an already 1-D block_lens for the single
            # request; either way the wrapper receives a [batch] vector.
            bl = bl.reshape(-1)

    if bl is None:
        # Default: treat every column of the block table as a valid block.
        bl = torch.full((bt.shape[0],), bt.shape[1], dtype=torch.int32, device=bt.device)

    k_out = _gather_dense(
        entries_fp8=self.paged.entries_fp8,
        entries_rope=self.paged.entries_rope,
        inv_scale=self.paged.inv_scale,
        block_table=bt,
        block_lens=bl,
        entries_per_block=epb,
        head_dim=hd,
        rope_dim=rd,
    )

    # k_out is (batch, total_entries, hd). With batch == 1 the shape is
    # unchanged; otherwise squeeze(0) is a no-op and a leading dim is added.
    k_compressed = k_out.squeeze(0).unsqueeze(0)
    # The same gathered entries are returned for V (as in the CSA path).
    v_compressed = k_compressed.clone()
    return k_compressed, v_compressed
|
||||
|
||||
def gather_swa_kv(self) -> tuple[torch.Tensor, torch.Tensor]:
    """Collect the SWA window of the state cache as dense BF16 tensors.

    Returns:
        ``(k_swa, v_swa)``; for a single request each is
        (1, swa_len, head_dim) BF16. ``v_swa`` is a copy of ``k_swa``.
    """
    from dsv4.kernels.cache.gather import gather_swa_kv as _gather_window

    head_dim = self.head_dim
    rope_dim = self.schema.rope_dim

    # Everything the wrapper needs, passed by keyword.
    kernel_args = dict(
        swa_fp8=self.state.swa_fp8,
        swa_rope=self.state.swa_rope,
        swa_inv=self.state.swa_inv,
        swa_pos=self.state.swa_pos,
        request_slots=self.request_slots,
        head_dim=head_dim,
        rope_dim=rope_dim,
    )
    dense = _gather_window(**kernel_args)

    # dense is (batch, n_win, hd); with batch == 1 the shape is unchanged.
    k_swa = dense.squeeze(0).unsqueeze(0)
    return k_swa, k_swa.clone()
|
||||
|
||||
def read_swa_view(self) -> "SWAView":
|
||||
"""Return a typed view of the SWA window for this batch."""
|
||||
return SWAView(
|
||||
@@ -111,6 +250,11 @@ class LayerCacheHandle:
|
||||
block_lens=self.block_lens,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Initialize _num_query_heads (must be set by the manager at construction)
|
||||
if not hasattr(self, '_num_query_heads'):
|
||||
self._num_query_heads = 0
|
||||
|
||||
|
||||
# Typed views — simple dataclasses, no logic. The FMHA / indexer / SWA
|
||||
# kernels accept these to keep their signatures clean.
|
||||
|
||||
5
dsv4/cache/manager.py
vendored
5
dsv4/cache/manager.py
vendored
@@ -173,15 +173,18 @@ class KVCacheManager:
|
||||
block_table = None
|
||||
block_lens = None
|
||||
|
||||
return LayerCacheHandle(
|
||||
handle = LayerCacheHandle(
|
||||
paged=paged,
|
||||
state=state,
|
||||
schema=self.schemas[layer_idx],
|
||||
request_slots=request_slots,
|
||||
positions=positions,
|
||||
request_ids=request_ids,
|
||||
block_table=block_table,
|
||||
block_lens=block_lens,
|
||||
)
|
||||
handle.num_query_heads = self.config.num_query_heads
|
||||
return handle
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Diagnostics
|
||||
|
||||
115
dsv4/kernels/cache/gather.py
vendored
Normal file
115
dsv4/kernels/cache/gather.py
vendored
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Python wrappers for cache gather kernels.
|
||||
|
||||
Provides the bridge between LayerCacheHandle and the raw CUDA gather ops.
|
||||
Loaded via torch.utils.cpp_extension (JIT, sm_100a).
|
||||
|
||||
Three gather paths:
|
||||
1. gather_compressed_kv: CSA — top-k selected compressed entries
|
||||
2. gather_all_compressed_kv: HCA — all compressed entries (no indexer)
|
||||
3. gather_swa_kv: SWA window entries from state cache
|
||||
|
||||
All outputs are dense BF16 tensors ready for FMHA consumption.
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
_kernel_module = None
|
||||
_kernel_dir = os.path.join(os.path.dirname(__file__), "cuda")
|
||||
|
||||
|
||||
def _get_kernel_module():
    """Build (on first use) and return the JIT-compiled gather extension.

    The compiled module is cached in the module-level ``_kernel_module`` so
    later calls return it without rebuilding.

    Each ``.cu`` source is looked up in ``_kernel_dir`` first and then in the
    sibling ``../cuda`` directory. This file lives in ``kernels/cache/``
    while ``gather_swa.cu`` is added under ``kernels/cuda/``, so
    ``_kernel_dir`` alone would not find it.

    Raises:
        FileNotFoundError: if a source file is in neither directory.
    """
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module

    here = os.path.dirname(os.path.abspath(__file__))
    search_dirs = [_kernel_dir, os.path.join(os.path.dirname(here), "cuda")]

    sources = []
    for filename in ("gather_kv.cu", "gather_swa.cu"):
        for directory in search_dirs:
            candidate = os.path.join(directory, filename)
            if os.path.isfile(candidate):
                sources.append(candidate)
                break
        else:
            raise FileNotFoundError(
                f"cache gather kernel source {filename!r} not found in any of: {search_dirs}"
            )

    _kernel_module = load(
        name="cache_gather",
        sources=sources,
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module
|
||||
|
||||
|
||||
def gather_compressed_kv(
    entries_fp8: torch.Tensor,   # [num_blocks, epb, fp8_dim] uint8
    entries_rope: torch.Tensor,  # [num_blocks, epb, rope_dim] BF16
    inv_scale: torch.Tensor,     # [num_blocks, epb] FP32
    topk_indices: torch.Tensor,  # [T, top_k] int32
    block_table: torch.Tensor,   # [batch, max_logical_blocks] int32
    entries_per_block: int,
    head_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Gather top-k compressed KV entries into a dense BF16 tile.

    Args:
        entries_fp8, entries_rope, inv_scale: paged pool tensors, passed to
            the kernel unchanged.
        topk_indices: [T, top_k] selected entry indices.
        block_table: [batch, max_logical_blocks] logical-to-physical table.
        entries_per_block: entries stored per physical block.
        head_dim: width of each output row.
        rope_dim: width of the BF16 RoPE part of each entry.

    Returns:
        (T, top_k, head_dim) BF16 tensor on ``entries_fp8.device``. It is
        zero-initialised before the kernel runs.
    """
    T = topk_indices.size(0)
    top_k = topk_indices.size(1)
    output = torch.zeros(T, top_k, head_dim, dtype=torch.bfloat16, device=entries_fp8.device)

    # Nothing to gather: skip building/loading the extension and the launch.
    if output.numel() == 0:
        return output

    # The index tensors are documented as int32; normalise dtype and layout
    # before handing them to the extension (no-ops when already int32 and
    # contiguous).
    topk_indices = topk_indices.to(torch.int32).contiguous()
    block_table = block_table.to(torch.int32).contiguous()

    mod = _get_kernel_module()
    mod.gather_kv(
        entries_fp8, entries_rope, inv_scale,
        topk_indices, block_table, output,
        entries_per_block, rope_dim,
    )
    return output
|
||||
|
||||
|
||||
def gather_all_compressed_kv(
    entries_fp8: torch.Tensor,   # [num_blocks, epb, fp8_dim] uint8
    entries_rope: torch.Tensor,  # [num_blocks, epb, rope_dim] BF16
    inv_scale: torch.Tensor,     # [num_blocks, epb] FP32
    block_table: torch.Tensor,   # [batch, max_logical_blocks] int32
    block_lens: torch.Tensor,    # [batch] int32
    entries_per_block: int,
    head_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Gather ALL compressed KV entries for HCA (dense attention).

    Args:
        entries_fp8, entries_rope, inv_scale: paged pool tensors, passed to
            the kernel unchanged.
        block_table: [batch, max_logical_blocks] logical-to-physical table.
        block_lens: number of valid blocks per request. Values are clamped
            to ``[0, max_logical_blocks]`` so the kernel cannot be asked to
            read past the end of a block-table row.
        entries_per_block: entries stored per physical block.
        head_dim: width of each output row.
        rope_dim: width of the BF16 RoPE part of each entry.

    Returns:
        (batch, total_entries, head_dim) BF16 tensor, zero-initialised, where
        ``total_entries = block_lens.max() * entries_per_block`` (every
        request is padded to the longest one). An empty batch gives
        ``total_entries == 0``.
    """
    batch = block_table.size(0)
    max_logical_blocks = block_table.size(1)

    # Documented as int32; normalise dtype/layout for the extension.
    block_table = block_table.to(torch.int32).contiguous()
    block_lens = block_lens.to(torch.int32).clamp(min=0, max=max_logical_blocks).contiguous()

    # max() of an empty tensor raises, so an empty batch is handled here.
    max_blocks = int(block_lens.max().item()) if block_lens.numel() > 0 else 0
    total_entries = max_blocks * entries_per_block
    output = torch.zeros(batch, total_entries, head_dim, dtype=torch.bfloat16, device=entries_fp8.device)

    # Nothing to gather: skip building/loading the extension and the launch.
    if output.numel() == 0:
        return output

    mod = _get_kernel_module()
    mod.gather_all_compressed(
        entries_fp8, entries_rope, inv_scale,
        block_table, block_lens, output,
        entries_per_block, rope_dim,
    )
    return output
|
||||
|
||||
|
||||
def gather_swa_kv(
    swa_fp8: torch.Tensor,        # [max_req, n_win, fp8_dim] uint8
    swa_rope: torch.Tensor,       # [max_req, n_win, rope_dim] BF16
    swa_inv: torch.Tensor,        # [max_req, n_win] FP32
    swa_pos: torch.Tensor,        # [max_req, n_win] int32
    request_slots: torch.Tensor,  # [batch] int32
    head_dim: int,
    rope_dim: int,
) -> torch.Tensor:
    """Gather SWA window entries into a dense BF16 tile.

    Args:
        swa_fp8, swa_rope, swa_inv, swa_pos: state-cache ring-buffer tensors,
            passed to the kernel unchanged.
        request_slots: state-cache slot for each request in the batch.
        head_dim: width of each output row.
        rope_dim: width of the BF16 RoPE part of each entry.

    Returns:
        (batch, n_win, head_dim) BF16 tensor on ``swa_fp8.device``, where
        ``n_win = swa_fp8.size(1)``. It is zero-initialised before the
        kernel runs.
    """
    batch = request_slots.size(0)
    n_win = swa_fp8.size(1)
    output = torch.zeros(batch, n_win, head_dim, dtype=torch.bfloat16, device=swa_fp8.device)

    # Nothing to gather (e.g. empty batch): skip the extension entirely so a
    # zero-block kernel launch is never attempted.
    if output.numel() == 0:
        return output

    # Documented as int32; normalise dtype/layout for the extension.
    request_slots = request_slots.to(torch.int32).contiguous()

    mod = _get_kernel_module()
    mod.gather_swa(
        swa_fp8, swa_rope, swa_inv, swa_pos, request_slots,
        output, head_dim, rope_dim,
    )
    return output
|
||||
177
dsv4/kernels/cuda/gather_swa.cu
Normal file
177
dsv4/kernels/cuda/gather_swa.cu
Normal file
@@ -0,0 +1,177 @@
|
||||
// gather_swa.cu — Gather SWA window entries into a dense BF16 tile.
|
||||
//
|
||||
// Reads from the state cache's SWA ring buffer (FP8 + BF16 split layout).
|
||||
// One CTA per request. Each CTA iterates over the n_win positions in the
|
||||
// ring buffer, dequantizes FP8 → BF16, concatenates the RoPE half, and
|
||||
// writes to the dense output tensor.
|
||||
//
|
||||
// Output shape: [batch, n_win, head_dim] BF16 — consumed by FMHA.
|
||||
// Positions with swa_pos == -1 (unused slots) are zero-filled.
|
||||
|
||||
#include <cstdint>

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
|
||||
|
||||
|
||||
// Gathers each request's SWA window into output[b, w, :].
//
// Launch layout: blockIdx.x selects the request (block b handles request b;
// blocks with b >= batch exit immediately). The threads of a block stride
// over the feature dimension with step blockDim.x. No shared memory and no
// barriers are used.
//
// For every window entry w of slot = request_slots[b]:
//   * if slot is outside [0, max_requests) or swa_pos[slot, w] < 0, the
//     output row is zero-filled;
//   * otherwise the first (head_dim - rope_dim) elements are the FP8 (e4m3)
//     values scaled by swa_inv[slot, w], and the remaining rope_dim elements
//     are copied from swa_rope[slot, w, :].
//
// All tensors are indexed as dense row-major arrays with the shapes noted on
// the parameters. Offsets are computed in 64 bits so large caches cannot
// overflow 32-bit index arithmetic.
__global__ void gather_swa_kernel(
    // Inputs
    const uint8_t* __restrict__ swa_fp8,         // [max_req, n_win, fp8_dim]
    const __nv_bfloat16* __restrict__ swa_rope,  // [max_req, n_win, rope_dim]
    const float* __restrict__ swa_inv,           // [max_req, n_win]
    const int32_t* __restrict__ swa_pos,         // [max_req, n_win]
    const int32_t* __restrict__ request_slots,   // [batch] — state cache slot per request
    // Output
    __nv_bfloat16* __restrict__ output,          // [batch, n_win, head_dim] BF16
    // Geometry
    int batch, int n_win, int head_dim, int rope_dim, int max_requests
) {
    const int fp8_dim = head_dim - rope_dim;
    const int b = blockIdx.x;
    if (b >= batch) return;

    const int slot = request_slots[b];
    // A slot outside the cache would index out of bounds; treat it like an
    // unused entry and zero-fill instead.
    const bool slot_ok = (slot >= 0) && (slot < max_requests);

    for (int w = 0; w < n_win; ++w) {
        // Flat (slot, w) entry index; only dereferenced when slot_ok.
        const int64_t src = static_cast<int64_t>(slot) * n_win + w;
        __nv_bfloat16* out_row =
            output + (static_cast<int64_t>(b) * n_win + w) * head_dim;

        // pos depends only on (b, w), so this branch is uniform per block.
        const int pos = slot_ok ? swa_pos[src] : -1;
        if (pos < 0) {
            // Unused slot — zero fill
            for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
                out_row[d] = __float2bfloat16(0.0f);
            }
            continue;
        }

        const float s = swa_inv[src];
        const uint8_t* fp8_row = swa_fp8 + src * fp8_dim;
        const __nv_bfloat16* rope_row = swa_rope + src * rope_dim;

        // Dequantize FP8 half
        for (int d = threadIdx.x; d < fp8_dim; d += blockDim.x) {
            __nv_fp8_e4m3 fp8_val;
            fp8_val.__x = fp8_row[d];
            out_row[d] = __float2bfloat16(static_cast<float>(fp8_val) * s);
        }

        // Copy BF16 RoPE half
        for (int d = threadIdx.x; d < rope_dim; d += blockDim.x) {
            out_row[fp8_dim + d] = rope_row[d];
        }
    }
}
|
||||
|
||||
|
||||
// Host launcher for gather_swa_kernel.
//
// Validates the tensors, then launches one block per request (128 threads
// each) on PyTorch's current CUDA stream for the device holding swa_fp8.
// `output` must already be allocated as [batch, n_win, head_dim] BF16; it is
// written in place. An empty batch or empty window is a no-op, because a
// zero-block launch is an invalid configuration.
void gather_swa_cuda(
    torch::Tensor swa_fp8,
    torch::Tensor swa_rope,
    torch::Tensor swa_inv,
    torch::Tensor swa_pos,
    torch::Tensor request_slots,
    torch::Tensor output,
    int64_t head_dim, int64_t rope_dim
) {
    TORCH_CHECK(swa_fp8.is_cuda() && swa_rope.is_cuda() && swa_inv.is_cuda() &&
                swa_pos.is_cuda() && request_slots.is_cuda() && output.is_cuda(),
                "gather_swa: all tensors must be CUDA tensors");
    // The kernel indexes raw pointers as dense row-major arrays.
    TORCH_CHECK(swa_fp8.is_contiguous() && swa_rope.is_contiguous() &&
                swa_inv.is_contiguous() && swa_pos.is_contiguous() &&
                request_slots.is_contiguous() && output.is_contiguous(),
                "gather_swa: all tensors must be contiguous");
    TORCH_CHECK(swa_fp8.dim() == 3 && swa_rope.dim() == 3 && output.dim() == 3,
                "gather_swa: swa_fp8, swa_rope and output must be 3-D");
    TORCH_CHECK(rope_dim >= 0 && rope_dim <= head_dim,
                "gather_swa: rope_dim must be in [0, head_dim]");
    // The kernel strides the inputs by (head_dim - rope_dim) and rope_dim.
    TORCH_CHECK(swa_fp8.size(2) == head_dim - rope_dim && swa_rope.size(2) == rope_dim,
                "gather_swa: swa_fp8/swa_rope last dims do not match head_dim/rope_dim");

    const int batch = static_cast<int>(request_slots.size(0));
    const int n_win = static_cast<int>(swa_fp8.size(1));
    const int max_requests = static_cast<int>(swa_fp8.size(0));

    TORCH_CHECK(output.size(0) == batch && output.size(1) == n_win &&
                output.size(2) == head_dim,
                "gather_swa: output must be [batch, n_win, head_dim]");

    if (batch == 0 || n_win == 0) return;

    // Launch on the tensors' device and on the stream PyTorch is currently
    // using, so the kernel is ordered with surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(swa_fp8.device());
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    constexpr int kThreads = 128;
    gather_swa_kernel<<<batch, kThreads, 0, stream>>>(
        swa_fp8.data_ptr<uint8_t>(),
        reinterpret_cast<const __nv_bfloat16*>(swa_rope.data_ptr<at::BFloat16>()),
        swa_inv.data_ptr<float>(),
        swa_pos.data_ptr<int32_t>(),
        request_slots.data_ptr<int32_t>(),
        reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>()),
        batch, n_win, static_cast<int>(head_dim), static_cast<int>(rope_dim), max_requests
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
|
||||
// gather_all_compressed_kernel — Gather ALL compressed entries for HCA
// (no top-k, dense attention over the short compressed sequence).
//
// Launch layout: blockIdx.x selects the request (block b handles request b;
// blocks with b >= batch exit immediately). The threads of a block stride
// over the feature dimension with step blockDim.x. No shared memory and no
// barriers are used.
//
// For request b the kernel walks the first min(block_lens[b],
// max_logical_blocks) entries of its block_table row, skips negative
// physical block ids, and writes the entries of each remaining block
// back-to-back into output[b, 0..]. Each output row is the dequantized FP8
// part (head_dim - rope_dim elements, scaled by inv_scale) followed by the
// BF16 RoPE part.
//
// Output: [batch, out_entries, head_dim] BF16. `out_entries` is the real row
// count per request of the output tensor: it is used for the per-request
// base offset (requests may have different block_lens) and as a write bound.
// Rows that are not written keep whatever the caller initialised them to.
__global__ void gather_all_compressed_kernel(
    const uint8_t* __restrict__ entries_fp8,
    const __nv_bfloat16* __restrict__ entries_rope,
    const float* __restrict__ inv_scale,
    const int32_t* __restrict__ block_table,  // [batch, max_logical_blocks]
    const int32_t* __restrict__ block_lens,   // [batch] — valid blocks per request
    __nv_bfloat16* __restrict__ output,       // [batch, out_entries, head_dim]
    int batch, int entries_per_block, int head_dim,
    int rope_dim, int max_logical_blocks, int out_entries
) {
    const int fp8_dim = head_dim - rope_dim;
    const int b = blockIdx.x;
    if (b >= batch) return;

    // Never walk past the end of this request's block_table row.
    const int n_blocks = min(block_lens[b], max_logical_blocks);

    // Base of this request's rows, using the output's actual dim-1 size
    // (not this request's own n_blocks).
    __nv_bfloat16* out_base =
        output + static_cast<int64_t>(b) * out_entries * head_dim;
    int out_idx = 0;

    for (int lb = 0; lb < n_blocks; ++lb) {
        const int phys_block =
            block_table[static_cast<int64_t>(b) * max_logical_blocks + lb];
        if (phys_block < 0) continue;

        for (int e = 0; e < entries_per_block; ++e) {
            // Output is full for this request; uniform across the block.
            if (out_idx >= out_entries) return;

            // 64-bit so large pools cannot overflow the element offsets.
            const int64_t entry =
                static_cast<int64_t>(phys_block) * entries_per_block + e;
            const float s = inv_scale[entry];
            const uint8_t* fp8_row = entries_fp8 + entry * fp8_dim;
            const __nv_bfloat16* rope_row = entries_rope + entry * rope_dim;
            __nv_bfloat16* out_row =
                out_base + static_cast<int64_t>(out_idx) * head_dim;

            // Dequantize FP8 half
            for (int d = threadIdx.x; d < fp8_dim; d += blockDim.x) {
                __nv_fp8_e4m3 fp8_val;
                fp8_val.__x = fp8_row[d];
                out_row[d] = __float2bfloat16(static_cast<float>(fp8_val) * s);
            }

            // Copy BF16 RoPE half
            for (int d = threadIdx.x; d < rope_dim; d += blockDim.x) {
                out_row[fp8_dim + d] = rope_row[d];
            }
            ++out_idx;
        }
    }
}


// Host launcher for gather_all_compressed_kernel.
//
// head_dim is derived as entries_fp8.size(2) + entries_rope.size(2).
// Validates the tensors, then launches one block per request (128 threads
// each) on PyTorch's current CUDA stream for the device holding entries_fp8.
// `output` must already be allocated as [batch, out_entries, head_dim] BF16;
// it is written in place. An empty batch or empty output is a no-op, because
// a zero-block launch is an invalid configuration.
void gather_all_compressed_cuda(
    torch::Tensor entries_fp8,
    torch::Tensor entries_rope,
    torch::Tensor inv_scale,
    torch::Tensor block_table,
    torch::Tensor block_lens,
    torch::Tensor output,
    int64_t entries_per_block, int64_t rope_dim
) {
    TORCH_CHECK(entries_fp8.is_cuda() && entries_rope.is_cuda() && inv_scale.is_cuda() &&
                block_table.is_cuda() && block_lens.is_cuda() && output.is_cuda(),
                "gather_all_compressed: all tensors must be CUDA tensors");
    // The kernel indexes raw pointers as dense row-major arrays.
    TORCH_CHECK(entries_fp8.is_contiguous() && entries_rope.is_contiguous() &&
                inv_scale.is_contiguous() && block_table.is_contiguous() &&
                block_lens.is_contiguous() && output.is_contiguous(),
                "gather_all_compressed: all tensors must be contiguous");
    TORCH_CHECK(entries_fp8.dim() == 3 && entries_rope.dim() == 3 &&
                block_table.dim() == 2 && output.dim() == 3,
                "gather_all_compressed: unexpected tensor rank");
    // The kernel strides entries_rope by rope_dim.
    TORCH_CHECK(entries_rope.size(2) == rope_dim,
                "gather_all_compressed: rope_dim does not match entries_rope");

    const int batch = static_cast<int>(block_table.size(0));
    const int head_dim = static_cast<int>(entries_fp8.size(2) + entries_rope.size(2));
    const int max_logical_blocks = static_cast<int>(block_table.size(1));
    const int out_entries = static_cast<int>(output.size(1));

    TORCH_CHECK(block_lens.numel() >= batch,
                "gather_all_compressed: block_lens needs one entry per request");
    TORCH_CHECK(output.size(0) == batch && output.size(2) == head_dim,
                "gather_all_compressed: output must be [batch, out_entries, head_dim]");

    if (batch == 0 || out_entries == 0) return;

    // Launch on the tensors' device and on the stream PyTorch is currently
    // using, so the kernel is ordered with surrounding PyTorch work.
    const c10::cuda::CUDAGuard device_guard(entries_fp8.device());
    const cudaStream_t stream = c10::cuda::getCurrentCUDAStream().stream();

    constexpr int kThreads = 128;
    gather_all_compressed_kernel<<<batch, kThreads, 0, stream>>>(
        entries_fp8.data_ptr<uint8_t>(),
        reinterpret_cast<const __nv_bfloat16*>(entries_rope.data_ptr<at::BFloat16>()),
        inv_scale.data_ptr<float>(),
        block_table.data_ptr<int32_t>(),
        block_lens.data_ptr<int32_t>(),
        reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>()),
        batch, static_cast<int>(entries_per_block), head_dim,
        static_cast<int>(rope_dim), max_logical_blocks, out_entries
    );
    C10_CUDA_CHECK(cudaGetLastError());
}
|
||||
|
||||
|
||||
// Python bindings: exposes the two host launchers of this file under the
// names gather_swa and gather_all_compressed.
// NOTE(review): gather.py builds this file together with gather_kv.cu into a
// single extension and also calls mod.gather_kv — confirm that only one of
// the two sources expands PYBIND11_MODULE for that extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("gather_swa",
          &gather_swa_cuda,
          "Gather SWA window into dense BF16 tile")
     .def("gather_all_compressed",
          &gather_all_compressed_cuda,
          "Gather all compressed KV for HCA");
}
|
||||
Reference in New Issue
Block a user