[Bugfix] Fix quant RMS norm fusion for quantization with TMA-aligned scales (#33255)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -379,7 +379,9 @@ void per_token_group_quant_8bit_packed(const torch::Tensor& input,
 void per_token_group_quant_fp8(const torch::Tensor& input,
                                torch::Tensor& output_q, torch::Tensor& output_s,
                                int64_t group_size, double eps, double fp8_min,
-                               double fp8_max, bool scale_ue8m0) {
+                               double fp8_max, bool scale_ue8m0,
+                               bool dummy_is_scale_transposed = false,
+                               bool dummy_is_tma_aligned = false) {
   per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
                              fp8_min, fp8_max, scale_ue8m0);
 }
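
For illustration, a minimal standalone C++ sketch of the pattern this hunk applies (the names quantize/quantize_impl are hypothetical, not vLLM's): the wrapper gains defaulted dummy flags so its signature lines up with a variant that actually honors transposed / TMA-aligned scales, while it forwards unchanged to the underlying implementation, so existing call sites keep compiling and behavior is unaffected.

#include <cstdio>

// Underlying implementation; knows nothing about the new flags.
static void quantize_impl(int group_size, double eps) {
  std::printf("quantize: group_size=%d eps=%g\n", group_size, eps);
}

// Wrapper accepts extra defaulted dummy parameters purely for signature
// compatibility with a richer variant; the flags are deliberately ignored.
static void quantize(int group_size, double eps,
                     bool dummy_is_scale_transposed = false,
                     bool dummy_is_tma_aligned = false) {
  (void)dummy_is_scale_transposed;
  (void)dummy_is_tma_aligned;
  quantize_impl(group_size, eps);
}

int main() {
  quantize(128, 1e-10);               // existing call sites still compile
  quantize(128, 1e-10, false, true);  // new call sites can pass the flags
}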