wip: SE fused SwiGLU deinterleave fix
This commit is contained in:
@@ -27,6 +27,7 @@ from dsv4.ops.quantize import (
|
||||
from dsv4.ops.layouts import (
|
||||
make_b_k_major,
|
||||
interleave_l1_weights,
|
||||
deinterleave_l1_weights,
|
||||
)
|
||||
from dsv4.ops.gemm_runner import (
|
||||
run_nvfp4_grouped_gemm,
|
||||
@@ -284,7 +285,11 @@ class Nvfp4SharedExpert:
|
||||
global_scale_b=self._l1_gsb,
|
||||
swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
|
||||
)
|
||||
return l1_out[:num_tokens] # (num_tokens, intermediate_size) BF16, SwiGLU already applied
|
||||
l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]
|
||||
# Deinterleave to separate gate and up, then take up half (SwiGLU result)
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] # (num_tokens, 2*intermediate) deinterleaved
|
||||
intermediate = l1_deil[:, self.intermediate_size:] # up half = silu(gate)*up
|
||||
return intermediate # (num_tokens, intermediate_size) BF16
|
||||
|
||||
def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""L1 GEMM: activation × gate_up_weight → BF16."""
|
||||
|
||||
Reference in New Issue
Block a user