[CI/Build] Avoid CUDA initialization (#8534)
This commit is contained in:
@@ -4,6 +4,8 @@ import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils import seed_everything
|
||||
|
||||
NUM_HEADS = [(16, 16), (32, 8), (64, 8), (6, 1)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
@@ -82,7 +84,7 @@ def test_flashinfer_decode_with_paged_kv(
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
@@ -168,7 +170,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
|
||||
block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
@@ -266,7 +268,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
head_size: int, dtype: torch.dtype, block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
@@ -379,7 +381,7 @@ def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
) -> None:
|
||||
# test doesn't work for num_heads = (16,16)
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
|
||||
Reference in New Issue
Block a user