[CI/Build] Avoid CUDA initialization (#8534)
This commit is contained in:
@@ -15,9 +15,6 @@ CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
@@ -119,7 +116,7 @@ def cutlass_int8_gemm_helper(m: int,
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
@@ -157,7 +154,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
out_dtype: Type[torch.dtype],
|
||||
@@ -175,7 +172,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
use_bias: bool, device: str):
|
||||
@@ -207,7 +204,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(capability < 89,
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
|
||||
use_bias: bool):
|
||||
|
||||
Reference in New Issue
Block a user