[Performance][B200] silu_mul_quant: pack scales in int32 (#28358)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-11-13 13:16:55 -05:00
committed by GitHub
parent fdfd5075aa
commit fe1cd7704d
7 changed files with 466 additions and 151 deletions

View File

@@ -294,7 +294,7 @@ def torch_moe_impl(
# blockwise quant and de-quant.
assert not per_act_token_quant
a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128)
aq, aq_scale = per_token_group_quant_fp8(a, 128, use_ue8m0=False)
a = (
(aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1))
.view(a.shape)