diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py
index 2f92808e..eb824744 100644
--- a/dsv4/layers/shared_expert.py
+++ b/dsv4/layers/shared_expert.py
@@ -27,6 +27,7 @@ from dsv4.ops.quantize import (
 from dsv4.ops.layouts import (
     make_b_k_major,
     interleave_l1_weights,
+    deinterleave_l1_weights,
 )
 from dsv4.ops.gemm_runner import (
     run_nvfp4_grouped_gemm,
@@ -284,7 +285,11 @@ class Nvfp4SharedExpert:
             global_scale_b=self._l1_gsb,
             swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
         )
-        return l1_out[:num_tokens] # (num_tokens, intermediate_size) BF16, SwiGLU already applied
+        l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]
+        # Deinterleave to separate gate and up, then take up half (SwiGLU result)
+        l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] # (num_tokens, 2*intermediate) deinterleaved
+        intermediate = l1_deil[:, self.intermediate_size:] # up half = silu(gate)*up
+        return intermediate # (num_tokens, intermediate_size) BF16

     def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """L1 GEMM: activation × gate_up_weight → BF16."""