[CI] Bump mypy version (#34950)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -57,11 +57,11 @@ def opcheck_fp8_quant(
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("do_scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_token_fp8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, do_scale_ub: bool, seed: int
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
|
||||
@@ -70,7 +70,7 @@ def test_dynamic_per_token_fp8_quant(
|
||||
) # avoid nans
|
||||
|
||||
scale_ub = (
|
||||
torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None
|
||||
torch.mean(x).to(dtype=torch.float32, device="cuda") if do_scale_ub else None
|
||||
)
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
|
||||
ops_out, ops_scales = ops.scaled_fp8_quant(
|
||||
|
||||
Reference in New Issue
Block a user