[perf] Add fused MLA QKV + strided layernorm (#21116)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Mickaël Seznec
2025-07-22 16:07:44 +02:00
committed by GitHub
parent 0df4d9b06b
commit 4fb56914c5
7 changed files with 214 additions and 66 deletions

View File

@@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();
@@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();