diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 0dc0e89e..8743f40c 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -166,9 +166,9 @@ class Nvfp4MoE: self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) # Pre-allocated L1 GEMM output — avoids torch.zeros() in run_fused_swiglu_grouped_gemm - # Shape: (max_tokens * top_k, intermediate_size) — max possible L1 output + # Shape: (max_tokens * top_k, 2*intermediate_size) — gate+up combined self._l1_out_buf = torch.zeros( - self.max_num_tokens * self.top_k, self.intermediate_size, + self.max_num_tokens * self.top_k, 2 * self.intermediate_size, dtype=torch.bfloat16, device=self.device ) diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index af3951fa..7a02e58c 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -184,8 +184,9 @@ class Nvfp4SharedExpert: self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) # Pre-allocated L1 output buffer for graph capture + # L1 produces gate+up combined: 2 * intermediate_size BF16 columns self._l1_out_buf = torch.zeros( - max_rows, self.intermediate_size, + max_rows, 2 * self.intermediate_size, dtype=torch.bfloat16, device=self.device ) @@ -365,25 +366,8 @@ class Nvfp4SharedExpert: from dsv4.ops.quantize import quantize_nvfp4_gpu_fused if not intermediate.is_contiguous(): intermediate = intermediate.contiguous() - # DEBUG: isolate async CUDA error - torch.cuda.synchronize() # catch any prior async error x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate) - try: - torch.cuda.synchronize() # catch error from quantize kernels - except RuntimeError as e: - print(f" SE L2: quantize_nvfp4_gpu_fused FAILED after sync: {e}", flush=True) - print(f" intermediate: shape={tuple(intermediate.shape)} dtype={intermediate.dtype} dev={intermediate.device}", flush=True) - raise - # DEBUG: check gsa values before assignment - try: - gsa_first = gsa_l2_gpu[0].item() # DEBUG: read value - print(f" SE L2 gsa[0]={gsa_first:.6f} shape={tuple(gsa_l2_gpu.shape)} dev={gsa_l2_gpu.device} buf_dev={self._l2_gsa_buf.device} buf_shape={tuple(self._l2_gsa_buf.shape)}", flush=True) - # Try copy_ instead of scalar assign - self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].contiguous()) - print(f" SE L2 gsa copy_ succeeded", flush=True) - except RuntimeError as e: - print(f" SE L2: gsa assignment FAILED: {e}", flush=True) - raise + self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU→GPU, no sync, graph-capturable else: x_fp4, x_sf = quantize_activation_nvfp4( intermediate, self._l2_activation_global_scale @@ -425,15 +409,11 @@ class Nvfp4SharedExpert: """Actual implementation — called via custom autograd to be torch.compile-safe.""" self._ensure_initialized() - # DEBUG: check input - print(f" SE input: shape={tuple(hidden_states.shape)} |max|={hidden_states.abs().max().item():.6f} nan={torch.isnan(hidden_states).any().item()}", flush=True) - if self._fused_swiglu: # P1: Fused L1 GEMM + SwiGLU + clamp in one kernel launch intermediate = self._run_l1_fused(hidden_states) else: l1_out = self._run_l1(hidden_states) - print(f" SE L1 out: shape={tuple(l1_out.shape)} |max|={l1_out.abs().max().item() if l1_out.numel() > 0 else 'EMPTY'} nan={torch.isnan(l1_out).any().item() if l1_out.numel() > 0 else 'N/A'}", flush=True) if l1_out.shape[1] < 2 * self.intermediate_size: print(f" WARNING: l1_out shape {l1_out.shape} < expected (N, {2*self.intermediate_size})", flush=True) @@ -447,8 +427,6 @@ class Nvfp4SharedExpert: gate = gate.clamp(max=self.swiglu_limit) up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) intermediate = torch.nn.functional.silu(gate) * up - # DEBUG: check intermediate before L2 - print(f" SE intermediate: shape={tuple(intermediate.shape)} |max|={intermediate.abs().max().item():.6f} nan={torch.isnan(intermediate).any().item()}", flush=True) output = self._run_l2(intermediate) return output diff --git a/dsv4/ops/quantize.py b/dsv4/ops/quantize.py index 10a51388..3a189555 100644 --- a/dsv4/ops/quantize.py +++ b/dsv4/ops/quantize.py @@ -334,8 +334,6 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0): # For M=1: gsa_gpu is (1,) contiguous — zero allocation quant_mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) x_fp4, x_sf = quant_mod.quantize_nvfp4_from_buffer(x_bf16, gsa_gpu) - # DEBUG: sync to catch async errors from the quantize kernels - torch.cuda.synchronize() return x_fp4, x_sf, gsa_gpu