Share padded_x_sf and output buffers across layers to save ~300 MB

Per-layer padded_xsf (2.4 MB) + output_buf (4.2 MB) × 60 layers = ~400 MB.
Sharing reduces to ~3.6 MB total. Layers run sequentially during both
capture and replay.
This commit is contained in:
2026-05-17 16:05:53 +00:00
parent 3d0b1408b4
commit ea8acf9852

View File

@@ -119,14 +119,24 @@ class CuTeDSLMoERunner:
for _ in range(self.num_experts)
]
# Padded x_sf buffers: num_experts * max_chunks * 128 rows (fixed layout)
# Padded x_sf buffers: SHARED across all runners (not per-layer)
# Same reasoning as padded_hidden/activated — layers run sequentially.
max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128
self._padded_x_sf_buf_l1 = torch.zeros(
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
self._padded_x_sf_buf_l2 = torch.zeros(
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
'xsf_l1': torch.zeros(
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn),
'xsf_l2': torch.zeros(
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn),
'output': torch.zeros(
self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device
),
})
self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1']
self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2']
self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output']
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)