fix: replace _allocate_buffers with _ensure_buffer_size for dynamic sizing

This commit is contained in:
2026-05-20 00:02:10 +00:00
parent 09669dded4
commit 268dc251c1

View File

@@ -74,23 +74,22 @@ class CuTeDSLNvfp4Linear:
self.sf = None
self.gs = None
def _allocate_buffers(self):
"""Pre-allocate buffers at max size for cudagraph compatibility."""
max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
def _ensure_buffer_size(self, num_tokens: int):
"""Ensure the padded buffer is large enough for num_tokens."""
needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128
if self._padded_x_fp4_buf is not None and self._padded_x_fp4_buf.shape[0] >= needed_rows:
return # Already big enough
self._padded_x_fp4_buf = torch.zeros(
max_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
needed_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
).view(torch.float4_e2m1fn_x2)
self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
self._gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
self._buffers_allocated = True
def _ensure_initialized(self):
if self._mat_b is None:
self.finalize_weights()
if not self._buffers_allocated:
self._allocate_buffers()
def _assemble_scales_single_group(self, x_sf):
"""Assemble 2D-side activation scales for num_groups=1."""