[Bugfix] opcheck false mutation error in rms_norm_per_block_quant (#36688) (#36779)

Signed-off-by: Krish Gupta <krishom70@gmail.com>
This commit is contained in:
Krish Gupta
2026-03-17 02:41:33 +05:30
committed by GitHub
parent e6ae4b1be1
commit c0f011918d
2 changed files with 19 additions and 9 deletions

View File

@@ -286,6 +286,15 @@ void rms_norm_per_block_quant(torch::Tensor& out, torch::Tensor const& input,
"Outer scale stride must be 1 when scales are not transposed");
}
int64_t hidden_size = input.size(-1);
TORCH_CHECK(hidden_size > 0 && hidden_size % group_size == 0,
"hidden_size must be a positive multiple of group_size");
int64_t num_tokens = input.numel() / hidden_size;
int64_t num_groups = hidden_size / group_size;
TORCH_CHECK(scales.numel() >= num_tokens * num_groups,
"scales buffer too small: need ", num_tokens * num_groups,
" elements, got ", scales.numel());
rms_norm_per_block_quant_dispatch(out, input, weight, scales, group_size,
var_epsilon, scale_ub, residual,
is_scale_transposed);

View File

@@ -280,21 +280,22 @@ def test_rms_norm(
assert torch.allclose(ref_residual, ops_residual)
output = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
if group_size is None:
scales = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
opcheck(
torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual),
)
else:
# TODO(luka/eliza) opcheck is broken?
# Somehow the cloned args are getting mutated in-place,
# which causes the opcheck to fail.
# https://github.com/vllm-project/vllm/issues/36688
return
assert hidden_size % group_size[1] == 0
num_groups = hidden_size // group_size[1]
scales = torch.empty(
(num_groups, num_tokens),
device=x.device,
dtype=torch.float32,
).transpose(0, 1)
opcheck(
torch.ops._C.rms_norm_per_block_quant,
(