[BugFix] Fix assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] in Blackwell Quantized MoE Test (#32362)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -247,8 +247,8 @@ def scaled_dequantize(
|
||||
if group_shape is not None:
|
||||
group_shape = _normalize_quant_group_shape(x_q, group_shape)
|
||||
|
||||
if x_s.ndim == 0: # scalar
|
||||
x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor
|
||||
if x_s.numel() == 1: # scalar
|
||||
x_s = x_s.reshape(1, 1) # normalize all scalar-like tensors to (1, 1)
|
||||
if x_s.ndim == 1:
|
||||
if group_shape is None:
|
||||
raise AssertionError(
|
||||
|
||||
Reference in New Issue
Block a user