diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 9892c360d..7a0aff80e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -1385,8 +1385,11 @@ class FlashInferImpl(AttentionImpl): ) else: assert isinstance(attn_metadata.prefill, TRTLLMPrefill) - # prefill_query may be non-contiguous - prefill_query = prefill_query.contiguous() + # prefill_query may be non-contiguous or have degenerate strides + # First ensure memory contiguity, then fix degenerate strides + # with reshape. contiguous() alone doesn't fix degenerate + # strides when a dimension has size 1. + prefill_query = prefill_query.contiguous().reshape(prefill_query.shape) workspace_buffer = _get_trtllm_gen_workspace_buffer() block_tables_prefill = attn_metadata.prefill.block_tables seq_lens_prefill = attn_metadata.prefill.seq_lens @@ -1495,9 +1498,12 @@ class FlashInferImpl(AttentionImpl): out=output[:num_decode_tokens], ) else: - # decode_query may be non-contiguous + # decode_query may be non-contiguous or have degenerate strides assert isinstance(attn_metadata.decode, TRTLLMDecode) - decode_query = decode_query.contiguous() + # First ensure memory contiguity, then fix degenerate strides + # with reshape. contiguous() alone doesn't fix degenerate + # strides when a dimension has size 1. + decode_query = decode_query.contiguous().reshape(decode_query.shape) workspace_buffer = _get_trtllm_gen_workspace_buffer() block_tables_decode = attn_metadata.decode.block_tables seq_lens_decode = attn_metadata.decode.seq_lens