[Bugfix] Fix persistent_masked_m_silu_mul_quant tests (#28366)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Varun Sundar Rabindranath
Date: 2025-11-10 12:21:52 -05:00 (committed by GitHub)
Parent: d0e186c16f
Commit: b039bfda8f
3 changed files with 16 additions and 7 deletions

@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     num_parallel_tokens=16,
     group_size: int = 128,
+    use_ue8m0: bool | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
     y has shape (E, T, 2*H). The first half of the last dimension is
@@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
         device=y.device,
     )
-    use_ue8m0 = is_deep_gemm_e8m0_used()
+    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index