Signed-off-by: Krish Gupta <krishom70@gmail.com>
This commit is contained in:
@@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
"Outer scale stride must be 1 when scales are not transposed");
|
||||
}
|
||||
|
||||
int64_t hidden_size = input.size(-1);
|
||||
TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0,
|
||||
"hidden_size must be a positive multiple of group_size");
|
||||
int64_t num_tokens = input.numel() / hidden_size;
|
||||
int64_t num_groups = hidden_size / group_size;
|
||||
TORCH_CHECK(scales.numel() >= num_tokens * num_groups,
|
||||
"scales buffer too small: need ", num_tokens * num_groups,
|
||||
" elements, got ", scales.numel());
|
||||
|
||||
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
|
||||
var_epsilon, scale_ub, residual,
|
||||
is_scale_transposed);
|
||||
|
||||
Reference in New Issue
Block a user