[CI/Build] Avoid CUDA initialization (#8534)
@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
 
-IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and current_platform.get_device_capability()[0] >= 8)
+IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
 
 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
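The first hunk replaces the import-time torch.cuda.is_available() / get_device_capability() check with current_platform.has_device_capability(80), so merely importing the blocksparse attention module no longer initializes a CUDA context on hosts without a usable GPU (e.g. CI builders). The sketch below shows one way such a capability check can be answered without touching torch.cuda, by querying NVML instead; the helper name, the pynvml dependency, and the major*10+minor encoding of the capability value are illustrative assumptions for this sketch, not vLLM's actual implementation.

import pynvml  # assumed installed for this sketch; NVML queries do not create a CUDA context


def has_device_capability(capability: int, device_id: int = 0) -> bool:
    """Illustrative stand-in: report whether GPU `device_id` has compute
    capability >= `capability` (encoded as major*10 + minor), without
    calling any torch.cuda function."""
    try:
        pynvml.nvmlInit()
    except pynvml.NVMLError:
        # No NVIDIA driver available: behave as "no capable device".
        return False
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        return major * 10 + minor >= capability
    except pynvml.NVMLError:
        return False
    finally:
        pynvml.nvmlShutdown()


# Usage mirroring the new module-level constant in the diff:
IS_COMPUTE_8_OR_ABOVE = has_device_capability(80)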
@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
         use_spda = is_hip() or is_cpu() or not \
             IS_COMPUTE_8_OR_ABOVE
         device = device or (torch.cuda.current_device()
-                            if torch.cuda.is_available() else "cpu")
+                            if current_platform.is_cuda_alike() else "cpu")
         device = torch.device(device)
         # NOTE: vllm CPU backend support BF16 instead of FP16.
         dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
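The second hunk applies the same idea to the default-device selection in LocalStridedBlockSparseAttn: the fallback to the current CUDA device is now gated on current_platform.is_cuda_alike() (which, in vLLM, covers the CUDA and ROCm backends) rather than torch.cuda.is_available(), so constructing the module on a CPU-only host also avoids initializing CUDA. A minimal sketch of the resulting device-resolution pattern, with the platform check passed in as a plain boolean stand-in for the vLLM helper:

import torch


def resolve_device(device=None, platform_is_cuda_alike: bool = False) -> torch.device:
    """Illustrative stand-in for the constructor logic in the diff.

    `platform_is_cuda_alike` plays the role of current_platform.is_cuda_alike();
    torch.cuda is only touched when it is True.
    """
    device = device or (torch.cuda.current_device()
                        if platform_is_cuda_alike else "cpu")
    return torch.device(device)


# On a CPU-only host this never calls into torch.cuda:
assert resolve_device() == torch.device("cpu")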