[BUGFIX] Fix degenerate strides in TRTLLM query tensors for FlashInfer backend. Fixes issue #32353 (#32417)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-01-19 04:57:32 +04:00
committed by GitHub
parent f5d1740030
commit 6101a26dc9

View File

@@ -1385,8 +1385,11 @@ class FlashInferImpl(AttentionImpl):
)
else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous
prefill_query = prefill_query.contiguous()
# prefill_query may be non-contiguous or have degenerate strides
# First ensure memory contiguity, then fix degenerate strides
# with reshape. contiguous() alone doesn't fix degenerate
# strides when a dimension has size 1.
prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.prefill.block_tables
seq_lens_prefill = attn_metadata.prefill.seq_lens
@@ -1495,9 +1498,12 @@ class FlashInferImpl(AttentionImpl):
out=output[:num_decode_tokens],
)
else:
# decode_query may be non-contiguous
# decode_query may be non-contiguous or have degenerate strides
assert isinstance(attn_metadata.decode, TRTLLMDecode)
decode_query = decode_query.contiguous()
# First ensure memory contiguity, then fix degenerate strides
# with reshape. contiguous() alone doesn't fix degenerate
# strides when a dimension has size 1.
decode_query = decode_query.contiguous().reshape(decode_query.shape)
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.decode.block_tables
seq_lens_decode = attn_metadata.decode.seq_lens