[Bugfix] opcheck false mutation error in rms_norm_per_block_quant (#36688) (#36779)

Signed-off-by: Krish Gupta <krishom70@gmail.com>
This commit is contained in:
Krish Gupta
2026-03-17 02:41:33 +05:30
committed by GitHub
parent e6ae4b1be1
commit c0f011918d
2 changed files with 19 additions and 9 deletions

View File

@@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
"Outer scale stride must be 1 when scales are not transposed");
}
int64_t hidden_size = input.size(-1);
TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0,
"hidden_size must be a positive multiple of group_size");
int64_t num_tokens = input.numel() / hidden_size;
int64_t num_groups = hidden_size / group_size;
TORCH_CHECK(scales.numel() >= num_tokens * num_groups,
"scales buffer too small: need ", num_tokens * num_groups,
" elements, got ", scales.numel());
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
var_epsilon, scale_ub, residual,
is_scale_transposed);