Share padded_x_sf and output buffers across layers to save ~390 MB
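Per-layer padded_xsf (2.4 MB) + output_buf (4.2 MB) = 6.6 MB per layer; × 60 layers = ~400 MB. Sharing a single set of buffers reduces this to ~6.6 MB total. This is safe because layers run sequentially during both capture and replay, so no two layers use the buffers at the same time.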
This commit is contained in:
@@ -119,14 +119,24 @@ class CuTeDSLMoERunner:
|
||||
for _ in range(self.num_experts)
|
||||
]
|
||||
|
||||
# Padded x_sf buffers: num_experts * max_chunks * 128 rows (fixed layout)
|
||||
# Padded x_sf and output buffers: SHARED across all runners on the same device (not per-layer)
|
||||
# Same reasoning as padded_hidden/activated — layers run sequentially.
|
||||
max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128
|
||||
self._padded_x_sf_buf_l1 = torch.zeros(
|
||||
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
self._padded_x_sf_buf_l2 = torch.zeros(
|
||||
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]:
|
||||
CuTeDSLMoERunner._shared_padded_bufs[device_key].update({
|
||||
'xsf_l1': torch.zeros(
|
||||
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn),
|
||||
'xsf_l2': torch.zeros(
|
||||
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn),
|
||||
'output': torch.zeros(
|
||||
self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
),
|
||||
})
|
||||
self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1']
|
||||
self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2']
|
||||
self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output']
|
||||
|
||||
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
|
||||
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
|
||||
Reference in New Issue
Block a user