[Bugfix] Fix quant RMS norm fusion for quantization with TMA-aligned scales (#33255)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Authored-by: ElizaWszola, 2026-02-18 08:35:04 +01:00 (committed by GitHub)
Parent: a49ea5a58f
Commit: a88b3be7c4
12 changed files with 234 additions and 75 deletions


@@ -379,7 +379,9 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
 void per_token_group_quant_fp8(const torch::Tensor& input,
                                torch::Tensor& output_q, torch::Tensor& output_s,
                                int64_t group_size, double eps, double fp8_min,
-                               double fp8_max, bool scale_ue8m0) {
+                               double fp8_max, bool scale_ue8m0,
+                               bool dummy_is_scale_transposed = false,
+                               bool dummy_is_tma_aligned = false) {
   per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                              fp8_min, fp8_max, scale_ue8m0);
 }