[hardware][misc] introduce platform abstraction (#6080)

This commit is contained in:
youkaichao
2024-07-02 20:12:22 -07:00
committed by GitHub
parent 9d6a8daa87
commit 482045ee77
16 changed files with 113 additions and 29 deletions

View File

@@ -2,13 +2,14 @@ import math
 import torch
-from vllm.utils import get_device_capability_stateless, is_cpu, is_hip
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
 IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and get_device_capability_stateless()[0] >= 8)
+                         and current_platform.get_device_capability()[0] >= 8)
 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd