[BugFix] Fix assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] in Blackwell Quantized MoE Test (#32362)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Lucas Wilkinson
2026-01-15 11:19:12 -07:00
committed by GitHub
parent 047413375c
commit c36ba69bda

View File

@@ -247,8 +247,8 @@ def scaled_dequantize(
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
if x_s.ndim == 0: # scalar
x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor
if x_s.numel() == 1: # scalar
x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1)
if x_s.ndim == 1:
if group_shape is None:
raise AssertionError(