[Kernel] Update cutlass_scaled_mm to support 2d group (blockwise) scaling (#11868)

This commit is contained in:
Lucas Wilkinson
2025-01-30 21:33:00 -05:00
committed by GitHub
parent 4078052f09
commit 9798b2fb00
25 changed files with 1924 additions and 346 deletions

View File

@@ -0,0 +1,33 @@
#pragma once
#include <torch/all.h>
namespace vllm {
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm