From 9fec7d609e4b919e4151707797fd0906e0784f27 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 21:33:59 +0000 Subject: [PATCH] Fix gsa_buffer shape mismatch for MoE (M>1 rows) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit compute_amax_gsa returns a scalar, but quantize_from_buffer expects (M,). Broadcast the scalar gsa to (M,) — all rows use the same gsa (global max). --- dsv4/ops/quantize.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/dsv4/ops/quantize.py b/dsv4/ops/quantize.py index 5ebef74d..72b7f2ae 100644 --- a/dsv4/ops/quantize.py +++ b/dsv4/ops/quantize.py @@ -269,8 +269,11 @@ def deinterleave_amax_quantize_nvfp4_fused(fused_bf16, intermediate, divisor=6.0 # Compute gsa from the fused output amax_mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"]) gsa_gpu = amax_mod.compute_amax_gsa(fused_bf16, divisor) + M = fused_bf16.shape[0] if gsa_gpu.dim() == 0: - gsa_gpu = gsa_gpu.reshape(1) + gsa_gpu = gsa_gpu.reshape(1).expand(M).contiguous() + elif gsa_gpu.shape[0] == 1 and M > 1: + gsa_gpu = gsa_gpu.expand(M).contiguous() # Deinterleave + quantize using gsa from GPU buffer quant_mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) x_fp4, x_sf = quant_mod.deinterleave_quantize_from_buffer(fused_bf16, intermediate, granularity, gsa_gpu) @@ -314,10 +317,13 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0): """ from dsv4.kernels.cuda.loader import get_cuda_module amax_mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"]) - gsa_gpu = amax_mod.compute_amax_gsa(x_bf16, divisor) # scalar GPU tensor for M=1 - # Reshape to (M,) for the quantize-from-buffer kernel + gsa_gpu = amax_mod.compute_amax_gsa(x_bf16, divisor) # scalar GPU tensor + # Broadcast to (M,) for the quantize-from-buffer kernel + M = x_bf16.shape[0] if gsa_gpu.dim() == 0: - gsa_gpu = gsa_gpu.reshape(1) # (1,) for M=1 + gsa_gpu = gsa_gpu.reshape(1).expand(M).contiguous() # (M,) all rows same gsa + elif gsa_gpu.shape[0] == 1 and M > 1: + gsa_gpu = gsa_gpu.expand(M).contiguous() quant_mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) x_fp4, x_sf = quant_mod.quantize_nvfp4_from_buffer(x_bf16, gsa_gpu) return x_fp4, x_sf, gsa_gpu