[Bugfix] Fix TP inference for Flex attention backend (#19657)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py
2025-06-16 19:21:37 +08:00
committed by GitHub
parent 4d5424029b
commit 1173804dca
5 changed files with 54 additions and 2 deletions


@@ -13,6 +13,7 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -236,7 +237,12 @@ class FlexAttentionMetadata:
     def build_block_mask(self) -> BlockMask:
         assert self.mask_mod is not None
-        return create_block_mask_compiled(
+        # FIXME: With TP>1, create_block_mask_compiled will raise
+        # CUDA error: an illegal memory access was encountered
+        create_block_mask_fn = (create_block_mask_compiled
+                                if get_tensor_model_parallel_world_size() == 1
+                                else create_block_mask)
+        return create_block_mask_fn(
             self.mask_mod,
             None,
             None,
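
For context, a minimal self-contained sketch of the fallback pattern this diff introduces: the compiled block-mask builder is used only when the tensor-parallel world size is 1, otherwise the eager `create_block_mask` is used. Here `get_tp_world_size` is a hypothetical stand-in for vLLM's `get_tensor_model_parallel_world_size`, `causal_mask` is just an example `mask_mod`, and the compiled wrapper is analogous to (not copied from) the `create_block_mask_compiled` referenced in the diff.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Compiled variant of the block-mask builder, analogous to the
# create_block_mask_compiled used in the diff above.
create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True)


def get_tp_world_size() -> int:
    # Hypothetical stand-in for
    # vllm.distributed.get_tensor_model_parallel_world_size().
    return 1


def causal_mask(b, h, q_idx, kv_idx):
    # Example mask_mod: each query may attend only to earlier or equal positions.
    return q_idx >= kv_idx


def build_block_mask(q_len: int, kv_len: int, device: str):
    # With TP > 1 the compiled builder has been observed to hit
    # "CUDA error: an illegal memory access was encountered",
    # so fall back to the eager create_block_mask in that case.
    build_fn = (create_block_mask_compiled
                if get_tp_world_size() == 1 else create_block_mask)
    return build_fn(causal_mask, B=None, H=None,
                    Q_LEN=q_len, KV_LEN=kv_len, device=device)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = build_block_mask(q_len=1024, kv_len=1024, device=device)
    print(mask)
```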