diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index a57dc111..0a46718c 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -119,14 +119,24 @@ class CuTeDSLMoERunner: for _ in range(self.num_experts) ] - # Padded x_sf buffers: num_experts * max_chunks * 128 rows (fixed layout) + # Padded x_sf buffers: SHARED across all runners (not per-layer) + # Same reasoning as padded_hidden/activated — layers run sequentially. max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128 - self._padded_x_sf_buf_l1 = torch.zeros( - max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device - ).to(torch.float8_e4m3fn) - self._padded_x_sf_buf_l2 = torch.zeros( - max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device - ).to(torch.float8_e4m3fn) + if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]: + CuTeDSLMoERunner._shared_padded_bufs[device_key].update({ + 'xsf_l1': torch.zeros( + max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn), + 'xsf_l2': torch.zeros( + max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn), + 'output': torch.zeros( + self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device + ), + }) + self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1'] + self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2'] + self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output'] # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture) self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)