[Bugfix] Fix persistent_masked_m_silu_mul_quant tests (#28366)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
commit: b039bfda8f
parent: d0e186c16f
@@ -100,6 +100,7 @@ def persistent_masked_m_silu_mul_quant(
     tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
     num_parallel_tokens=16,
     group_size: int = 128,
+    use_ue8m0: bool | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
     y has shape (E, T, 2*H). The first half of the last dimension is
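For orientation, here is a minimal PyTorch sketch of the semantics the docstring above describes: silu on the first half of the last dimension, an elementwise multiply by the second half, then FP8 quantization with one scale per group_size chunk per token. The helper name ref_silu_mul_quant, the e4m3 target dtype, and the exact scale layout are illustrative assumptions, not the vLLM kernel.

import torch
import torch.nn.functional as F

def ref_silu_mul_quant(y: torch.Tensor, group_size: int = 128):
    """Reference semantics only (hypothetical helper, not the vLLM kernel).

    y: (E, T, 2*H) with H divisible by group_size.
    Returns (q, scales): q is (E, T, H) in FP8, scales is (E, T, H // group_size).
    """
    E, T, two_h = y.shape
    H = two_h // 2
    x = F.silu(y[..., :H]) * y[..., H:]  # silu(first half) * second half
    groups = x.view(E, T, H // group_size, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3fn
    # One scale per (token, group): max-abs over the group, mapped into FP8 range.
    scales = groups.abs().amax(dim=-1).clamp(min=1e-12) / fp8_max
    q = (groups / scales.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return q.view(E, T, H), scales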
@@ -164,7 +165,7 @@ def persistent_masked_m_silu_mul_quant(
         device=y.device,
     )
 
-    use_ue8m0 = is_deep_gemm_e8m0_used()
+    use_ue8m0 = use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()
 
     cuda_arch = current_platform.get_device_capability(
         device_id=y.device.index
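The behavioral fix is in this second hunk: previously use_ue8m0 was unconditionally overwritten by is_deep_gemm_e8m0_used(), so tests could not pin the quantization mode; now an explicit argument wins and None defers to runtime detection. A minimal sketch of that resolution pattern follows (the standalone helper and the lambda stand-ins are illustrative, not vLLM code):

def resolve_use_ue8m0(use_ue8m0: bool | None, is_deep_gemm_e8m0_used) -> bool:
    # Mirrors the patched line: an explicit True/False wins; None falls back
    # to runtime detection, preserving the old default behavior.
    return use_ue8m0 if use_ue8m0 is not None else is_deep_gemm_e8m0_used()

# A test can now force either path regardless of the environment:
assert resolve_use_ue8m0(False, lambda: True) is False  # pinned by the caller
assert resolve_use_ue8m0(None, lambda: True) is True    # defers to detection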