[ROCm] Add skinny gemm bias support for dtypes fp16,bf16,fp8 (#24988)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
This commit is contained in:
Hashem Hashemi
2025-09-23 11:31:45 -07:00
committed by GitHub
parent 5abb117901
commit a3a7828010
7 changed files with 231 additions and 77 deletions

View File

@@ -5,11 +5,14 @@
torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
const int64_t rows_per_block);
torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias,
const int64_t CuCount);
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b,
const c10::optional<at::Tensor>& in_bias, at::Tensor& out_c,
const at::Tensor& scale_a, const at::Tensor& scale_b,
const int64_t CuCount);
void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,