wip: SE fused SwiGLU deinterleave fix
This commit is contained in:
@@ -27,6 +27,7 @@ from dsv4.ops.quantize import (
|
|||||||
from dsv4.ops.layouts import (
|
from dsv4.ops.layouts import (
|
||||||
make_b_k_major,
|
make_b_k_major,
|
||||||
interleave_l1_weights,
|
interleave_l1_weights,
|
||||||
|
deinterleave_l1_weights,
|
||||||
)
|
)
|
||||||
from dsv4.ops.gemm_runner import (
|
from dsv4.ops.gemm_runner import (
|
||||||
run_nvfp4_grouped_gemm,
|
run_nvfp4_grouped_gemm,
|
||||||
@@ -284,7 +285,11 @@ class Nvfp4SharedExpert:
|
|||||||
global_scale_b=self._l1_gsb,
|
global_scale_b=self._l1_gsb,
|
||||||
swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
|
swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
|
||||||
)
|
)
|
||||||
return l1_out[:num_tokens] # (num_tokens, intermediate_size) BF16, SwiGLU already applied
|
l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]
|
||||||
|
# Deinterleave to separate gate and up, then take up half (SwiGLU result)
|
||||||
|
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] # (num_tokens, 2*intermediate) deinterleaved
|
||||||
|
intermediate = l1_deil[:, self.intermediate_size:] # up half = silu(gate)*up
|
||||||
|
return intermediate # (num_tokens, intermediate_size) BF16
|
||||||
|
|
||||||
def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
"""L1 GEMM: activation × gate_up_weight → BF16."""
|
"""L1 GEMM: activation × gate_up_weight → BF16."""
|
||||||
|
|||||||
Reference in New Issue
Block a user