[hardware][misc] introduce platform abstraction (#6080)
This commit is contained in:
@@ -2,13 +2,14 @@ import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import get_device_capability_stateless, is_cpu, is_hip
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_cpu, is_hip
|
||||
|
||||
from .utils import (dense_to_crow_col, get_head_sliding_step,
|
||||
get_sparse_attn_mask)
|
||||
|
||||
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
|
||||
and get_device_capability_stateless()[0] >= 8)
|
||||
and current_platform.get_device_capability()[0] >= 8)
|
||||
|
||||
if IS_COMPUTE_8_OR_ABOVE:
|
||||
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
|
||||
|
||||
Reference in New Issue
Block a user