[MISC] Add strict contiguity check for FlashInfer attention tensors (#32008)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
This commit is contained in:
@@ -64,6 +64,36 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def is_strictly_contiguous(t: torch.Tensor) -> bool:
    """
    Check if tensor is contiguous AND has no degenerate strides.

    A degenerate stride occurs when a dimension has size 1 but the stride
    doesn't match the canonical contiguous layout. This can cause issues
    in some CUDA kernels that rely on stride values for memory access.

    For a C-contiguous tensor of shape (d0, d1, ..., dn), the expected
    strides are: stride[i] = product(shape[i+1:]) for all i, with stride[-1]=1.

    Example with torch.Size([16, 1, 8, 32]):
    - Canonical strides: (256, 256, 32, 1)
    - Degenerate strides: (256, 1, 32, 1)  # dim=1 has size=1, allowing
                                           # non-canonical stride in dim=0
    """
    # torch's is_contiguous() skips size-1 dims entirely, so it accepts
    # degenerate strides; the strict check below catches those.
    if not t.is_contiguous():
        return False

    # Check that strides match canonical contiguous layout.
    shape = t.shape
    strides = t.stride()
    expected_stride = 1
    for i in range(len(shape) - 1, -1, -1):
        if strides[i] != expected_stride:
            return False
        # Use max(1, dim) to mirror PyTorch's canonical stride computation:
        # a zero-sized dim contributes a factor of 1, not 0. Without this,
        # empty tensors (e.g. shape (3, 0), strides (1, 1)) would be
        # rejected even though torch reports them contiguous.
        expected_stride *= max(1, shape[i])
    return True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
|
||||
@@ -40,6 +40,7 @@ from vllm.utils.flashinfer import (
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import is_strictly_contiguous
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
@@ -1392,11 +1393,11 @@ class FlashInferImpl(AttentionImpl):
|
||||
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
assert prefill_query.is_contiguous()
|
||||
assert kv_cache_permute.is_contiguous()
|
||||
assert workspace_buffer.is_contiguous()
|
||||
assert block_tables_prefill.is_contiguous()
|
||||
assert seq_lens_prefill.is_contiguous()
|
||||
assert is_strictly_contiguous(prefill_query)
|
||||
assert is_strictly_contiguous(kv_cache_permute)
|
||||
assert is_strictly_contiguous(workspace_buffer)
|
||||
assert is_strictly_contiguous(block_tables_prefill)
|
||||
assert is_strictly_contiguous(seq_lens_prefill)
|
||||
|
||||
if output.dtype == FP4_DTYPE:
|
||||
assert self.o_sf_scale is not None
|
||||
@@ -1503,11 +1504,11 @@ class FlashInferImpl(AttentionImpl):
|
||||
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
assert decode_query.is_contiguous()
|
||||
assert kv_cache_permute.is_contiguous()
|
||||
assert workspace_buffer.is_contiguous()
|
||||
assert block_tables_decode.is_contiguous()
|
||||
assert seq_lens_decode.is_contiguous()
|
||||
assert is_strictly_contiguous(decode_query)
|
||||
assert is_strictly_contiguous(kv_cache_permute)
|
||||
assert is_strictly_contiguous(workspace_buffer)
|
||||
assert is_strictly_contiguous(block_tables_decode)
|
||||
assert is_strictly_contiguous(seq_lens_decode)
|
||||
|
||||
if output.dtype == FP4_DTYPE:
|
||||
assert self.o_sf_scale is not None
|
||||
|
||||
Reference in New Issue
Block a user