diff --git a/cutedsl/nvfp4_linear.py b/cutedsl/nvfp4_linear.py index d25d04c1..ec9d44a2 100644 --- a/cutedsl/nvfp4_linear.py +++ b/cutedsl/nvfp4_linear.py @@ -74,23 +74,22 @@ class CuTeDSLNvfp4Linear: self.sf = None self.gs = None - def _allocate_buffers(self): - """Pre-allocate buffers at max size for cudagraph compatibility.""" - max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128 + def _ensure_buffer_size(self, num_tokens: int): + """Ensure the padded buffer is large enough for num_tokens.""" + needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128 + if self._padded_x_fp4_buf is not None and self._padded_x_fp4_buf.shape[0] >= needed_rows: + return # Already big enough self._padded_x_fp4_buf = torch.zeros( - max_rows, self.in_features // 2, dtype=torch.uint8, device=self.device + needed_rows, self.in_features // 2, dtype=torch.uint8, device=self.device ).view(torch.float4_e2m1fn_x2) self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device) self._gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) - self._buffers_allocated = True def _ensure_initialized(self): if self._mat_b is None: self.finalize_weights() - if not self._buffers_allocated: - self._allocate_buffers() def _assemble_scales_single_group(self, x_sf): """Assemble 2D-side activation scales for num_groups=1."""