Signed-off-by: Krish Gupta <krishom70@gmail.com>
This commit is contained in:
@@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
"Outer scale stride must be 1 when scales are not transposed");
|
||||
}
|
||||
|
||||
int64_t hidden_size = input.size(-1);
|
||||
TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0,
|
||||
"hidden_size must be a positive multiple of group_size");
|
||||
int64_t num_tokens = input.numel() / hidden_size;
|
||||
int64_t num_groups = hidden_size / group_size;
|
||||
TORCH_CHECK(scales.numel() >= num_tokens * num_groups,
|
||||
"scales buffer too small: need ", num_tokens * num_groups,
|
||||
" elements, got ", scales.numel());
|
||||
|
||||
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
|
||||
var_epsilon, scale_ub, residual,
|
||||
is_scale_transposed);
|
||||
|
||||
@@ -280,21 +280,22 @@ def test_rms_norm(
|
||||
assert torch.allclose(ref_residual, ops_residual)
|
||||
|
||||
output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||
scales = torch.empty(
|
||||
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
if group_size is None:
|
||||
scales = torch.empty(
|
||||
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm_dynamic_per_token_quant,
|
||||
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
|
||||
)
|
||||
else:
|
||||
# TODO(luka/eliza) opcheck is broken?
|
||||
# Somehow the cloned args are getting mutated in-place,
|
||||
# which causes the opcheck to fail.
|
||||
# https://github.com/vllm-project/vllm/issues/36688
|
||||
return
|
||||
assert hidden_size % group_size[1] == 0
|
||||
num_groups = hidden_size // group_size[1]
|
||||
scales = torch.empty(
|
||||
(num_groups, num_tokens),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
).transpose(0, 1)
|
||||
opcheck(
|
||||
torch.ops._C.rms_norm_per_block_quant,
|
||||
(
|
||||
|
||||
Reference in New Issue
Block a user