[MISC] Add strict contiguity check for FlashInfer attention tensors (#32008)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
This commit is contained in:
Vadim Gimpelson
2026-01-11 00:40:05 +04:00
committed by GitHub
parent 6ea001cfb7
commit e15a5ff07b
2 changed files with 41 additions and 10 deletions

View File

@@ -64,6 +64,36 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
T = TypeVar("T")
def is_strictly_contiguous(t: torch.Tensor) -> bool:
    """
    Return True iff ``t`` is C-contiguous with exactly canonical strides.

    ``torch.Tensor.is_contiguous`` tolerates arbitrary stride values on
    size-1 dimensions, but some CUDA kernels read the raw strides when
    computing memory offsets, so such "degenerate" strides can still
    break them. For a C-contiguous tensor of shape (d0, d1, ..., dn) the
    canonical layout is stride[-1] == 1 and
    stride[i] == product(shape[i+1:]) for every i.

    Example with torch.Size([16, 1, 8, 32]):
        canonical strides:  (256, 256, 32, 1)
        degenerate strides: (256, 1, 32, 1)  # dim=1 has size=1, so
                                             # PyTorch still reports
                                             # the tensor as contiguous
    """
    if not t.is_contiguous():
        return False
    # Rebuild the canonical C-contiguous stride tuple from the shape and
    # require an exact match against the tensor's actual strides.
    canonical = []
    running = 1
    for size in reversed(t.shape):
        canonical.append(running)
        running *= size
    canonical.reverse()
    return tuple(t.stride()) == tuple(canonical)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""

View File

@@ -40,6 +40,7 @@ from vllm.utils.flashinfer import (
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import is_strictly_contiguous
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
@@ -1392,11 +1393,11 @@ class FlashInferImpl(AttentionImpl):
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
assert prefill_query.is_contiguous()
assert kv_cache_permute.is_contiguous()
assert workspace_buffer.is_contiguous()
assert block_tables_prefill.is_contiguous()
assert seq_lens_prefill.is_contiguous()
assert is_strictly_contiguous(prefill_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_prefill)
assert is_strictly_contiguous(seq_lens_prefill)
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None
@@ -1503,11 +1504,11 @@ class FlashInferImpl(AttentionImpl):
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
assert decode_query.is_contiguous()
assert kv_cache_permute.is_contiguous()
assert workspace_buffer.is_contiguous()
assert block_tables_decode.is_contiguous()
assert seq_lens_decode.is_contiguous()
assert is_strictly_contiguous(decode_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_decode)
assert is_strictly_contiguous(seq_lens_decode)
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None