wip: shared-expert (SE) fused SwiGLU deinterleave fix

This commit is contained in:
2026-06-02 08:41:00 +00:00
parent 1726cb64a9
commit f01d3f3eac

View File

@@ -27,6 +27,7 @@ from dsv4.ops.quantize import (
from dsv4.ops.layouts import (
make_b_k_major,
interleave_l1_weights,
deinterleave_l1_weights,
)
from dsv4.ops.gemm_runner import (
run_nvfp4_grouped_gemm,
@@ -284,7 +285,11 @@ class Nvfp4SharedExpert:
global_scale_b=self._l1_gsb,
swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
)
return l1_out[:num_tokens] # (num_tokens, intermediate_size) BF16, SwiGLU already applied
l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]
# Deinterleave to separate gate and up, then take up half (SwiGLU result)
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] # (num_tokens, 2*intermediate) deinterleaved
intermediate = l1_deil[:, self.intermediate_size:] # up half = silu(gate)*up
return intermediate # (num_tokens, intermediate_size) BF16
def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""L1 GEMM: activation × gate_up_weight → BF16."""