CUDA graph: Fix gsa broadcast — contiguous for prefill, reshape for decode
The stride-0 expand view for gsa_gpu caused illegal memory access in quantize_nvfp4_from_buffer kernel. The CUDA kernel may not handle stride-0 tensors correctly. Fix: - M=1 decode (graph-captured): just reshape scalar to (1,) — no alloc - M>1 prefill (not graph-captured): expand + contiguous — allocation OK
This commit is contained in:
@@ -323,18 +323,16 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0):
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
amax_mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"])
|
||||
gsa_gpu = amax_mod.compute_amax_gsa(x_bf16, divisor) # scalar GPU tensor
|
||||
# Broadcast to (M,) for the quantize-from-buffer kernel
|
||||
# CUDA-graph-safe: use reshape+expand without .contiguous() allocation.
|
||||
# For M=1 decode (the common graph-captured case), gsa is already scalar — no alloc.
|
||||
# For M>1 prefill (not graph-captured), expand creates a view, and the CUDA kernel
|
||||
# reads it correctly because the underlying data is contiguous (single value expanded).
|
||||
# If the kernel truly requires physical contiguity, the caller should pre-allocate
|
||||
# a buffer and use copy_ instead.
|
||||
# Broadcast to (M,) for the quantize-from-buffer kernel.
|
||||
# CUDA-graph-safe approach:
|
||||
# - For M=1 decode (graph-captured): just reshape to (1,) — no allocation.
|
||||
# - For M>1 prefill (not graph-captured): expand + contiguous is fine.
|
||||
M = x_bf16.shape[0]
|
||||
if gsa_gpu.dim() == 0:
|
||||
gsa_gpu = gsa_gpu.reshape(1).expand(M) # (M,) view — no allocation
|
||||
elif gsa_gpu.shape[0] == 1 and M > 1:
|
||||
gsa_gpu = gsa_gpu.expand(M) # view — no allocation
|
||||
gsa_gpu = gsa_gpu.reshape(1) # scalar → (1,) — no allocation
|
||||
if M > 1:
|
||||
gsa_gpu = gsa_gpu.expand(M).contiguous() # (M,) — allocation OK for prefill
|
||||
# For M=1: gsa_gpu is (1,) contiguous — zero allocation
|
||||
quant_mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
|
||||
x_fp4, x_sf = quant_mod.quantize_nvfp4_from_buffer(x_bf16, gsa_gpu)
|
||||
return x_fp4, x_sf, gsa_gpu
|
||||
|
||||
Reference in New Issue
Block a user