[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)

This commit is contained in:
Omer Ullman Argov
2025-11-30 18:02:40 +02:00
committed by GitHub
parent cd719de5cb
commit 39d28108f4
5 changed files with 98 additions and 22 deletions

View File

@@ -264,13 +264,20 @@ def make_test_weights(
quant_dtype: torch.dtype | str | None = None,
block_shape: list[int] | None = None,
per_out_ch_quant: bool = False,
make_gate: bool = True,
) -> tuple[
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
]:
return (
make_test_weight(
e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
e,
(2 if make_gate else 1) * n,
k,
in_dtype,
quant_dtype,
block_shape,
per_out_ch_quant,
),
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
)
@@ -297,6 +304,7 @@ def make_test_quant_config(
quant_dtype: torch.dtype | str | None = None,
per_act_token_quant: bool = False,
block_shape: list[int] | None = None,
make_gate: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
e,
@@ -306,6 +314,7 @@ def make_test_quant_config(
quant_dtype,
per_out_ch_quant=per_act_token_quant,
block_shape=block_shape,
make_gate=make_gate,
)
# Hacky/trivial scales for nvfp4.