Fix gsa_buffer shape mismatch for MoE (M>1 rows)
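compute_amax_gsa returns a scalar, but the quantize-from-buffer kernels (deinterleave_quantize_from_buffer and quantize_nvfp4_from_buffer) expect a gsa buffer of shape (M,). Broadcast the scalar gsa to (M,) so that all rows use the same gsa (global max).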
This commit is contained in:
@@ -269,8 +269,11 @@ def deinterleave_amax_quantize_nvfp4_fused(fused_bf16, intermediate, divisor=6.0
|
||||
# Compute gsa from the fused output
|
||||
amax_mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"])
|
||||
gsa_gpu = amax_mod.compute_amax_gsa(fused_bf16, divisor)
|
||||
M = fused_bf16.shape[0]
|
||||
if gsa_gpu.dim() == 0:
|
||||
gsa_gpu = gsa_gpu.reshape(1)
|
||||
gsa_gpu = gsa_gpu.reshape(1).expand(M).contiguous()
|
||||
elif gsa_gpu.shape[0] == 1 and M > 1:
|
||||
gsa_gpu = gsa_gpu.expand(M).contiguous()
|
||||
# Deinterleave + quantize using gsa from GPU buffer
|
||||
quant_mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
|
||||
x_fp4, x_sf = quant_mod.deinterleave_quantize_from_buffer(fused_bf16, intermediate, granularity, gsa_gpu)
|
||||
@@ -314,10 +317,13 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0):
|
||||
"""
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
amax_mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"])
|
||||
gsa_gpu = amax_mod.compute_amax_gsa(x_bf16, divisor) # scalar GPU tensor for M=1
|
||||
# Reshape to (M,) for the quantize-from-buffer kernel
|
||||
gsa_gpu = amax_mod.compute_amax_gsa(x_bf16, divisor) # scalar GPU tensor
|
||||
# Broadcast to (M,) for the quantize-from-buffer kernel
|
||||
M = x_bf16.shape[0]
|
||||
if gsa_gpu.dim() == 0:
|
||||
gsa_gpu = gsa_gpu.reshape(1) # (1,) for M=1
|
||||
gsa_gpu = gsa_gpu.reshape(1).expand(M).contiguous() # (M,) all rows same gsa
|
||||
elif gsa_gpu.shape[0] == 1 and M > 1:
|
||||
gsa_gpu = gsa_gpu.expand(M).contiguous()
|
||||
quant_mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
|
||||
x_fp4, x_sf = quant_mod.quantize_nvfp4_from_buffer(x_bf16, gsa_gpu)
|
||||
return x_fp4, x_sf, gsa_gpu
|
||||
|
||||
Reference in New Issue
Block a user