fix: replace _allocate_buffers with _ensure_buffer_size for dynamic sizing
This commit is contained in:
@@ -74,23 +74,22 @@ class CuTeDSLNvfp4Linear:
|
||||
self.sf = None
|
||||
self.gs = None
|
||||
|
||||
def _allocate_buffers(self):
|
||||
"""Pre-allocate buffers at max size for cudagraph compatibility."""
|
||||
max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128
|
||||
def _ensure_buffer_size(self, num_tokens: int):
|
||||
"""Ensure the padded buffer is large enough for num_tokens."""
|
||||
needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
if self._padded_x_fp4_buf is not None and self._padded_x_fp4_buf.shape[0] >= needed_rows:
|
||||
return # Already big enough
|
||||
|
||||
self._padded_x_fp4_buf = torch.zeros(
|
||||
max_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
|
||||
needed_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
|
||||
).view(torch.float4_e2m1fn_x2)
|
||||
|
||||
self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
|
||||
self._gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
|
||||
self._buffers_allocated = True
|
||||
|
||||
def _ensure_initialized(self):
|
||||
if self._mat_b is None:
|
||||
self.finalize_weights()
|
||||
if not self._buffers_allocated:
|
||||
self._allocate_buffers()
|
||||
|
||||
def _assemble_scales_single_group(self, x_sf):
|
||||
"""Assemble 2D-side activation scales for num_groups=1."""
|
||||
|
||||
Reference in New Issue
Block a user