[v1] Re-add fp32 support to v1 engine through FlexAttention (#19754)

Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py
2025-07-05 17:41:10 +08:00
committed by GitHub
parent 8aeaa910a2
commit 32c9be2200
8 changed files with 59 additions and 12 deletions


@@ -463,6 +463,13 @@ class FlexAttentionImpl(AttentionImpl):
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
+        # default M=64, N=64 may run out of shared memory on
+        # some GPUs with fp32, so we use smaller M and N.
+        extra_kernel_options = {
+            "BLOCK_M": 32,
+            "BLOCK_N": 32
+        } if query.dtype == torch.float32 else {}
         out = flex_attention_compiled(
             query,
             key_cache,
@@ -471,7 +478,10 @@ class FlexAttentionImpl(AttentionImpl):
             attn_metadata.block_mask,
             self.scale,
             enable_gqa=enable_gqa,
-            kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
+            kernel_options={
+                "FORCE_USE_FLEX_ATTENTION": True,
+                **extra_kernel_options
+            },
         )
         # Flex doesn't have an out variant today, rely on epilogue fusion
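
Below is a minimal, self-contained sketch (not part of this commit) of the pattern the diff applies: when the query is fp32, pass smaller BLOCK_M/BLOCK_N tile sizes through FlexAttention's kernel_options so the Triton kernel stays within shared-memory limits. The run_flex helper, the tensor shapes, and the standalone torch.compile call are illustrative assumptions; the commit itself does this inside FlexAttentionImpl.forward with a paged KV cache and a precomputed block mask.

import torch
from torch.nn.attention.flex_attention import flex_attention

# vLLM compiles flex_attention once and reuses it; do the same here.
flex_attention_compiled = torch.compile(flex_attention)


def run_flex(query, key, value, scale=None, enable_gqa=False):
    # The default 64x64 tiles can exhaust shared memory on some GPUs when the
    # inputs are fp32 (4 bytes per element vs. 2 for fp16/bf16), so fall back
    # to 32x32 tiles in that case, mirroring the diff above.
    extra_kernel_options = {
        "BLOCK_M": 32,
        "BLOCK_N": 32,
    } if query.dtype == torch.float32 else {}
    return flex_attention_compiled(
        query,
        key,
        value,
        scale=scale,
        enable_gqa=enable_gqa,
        kernel_options={
            # Copied from the diff above; vLLM always sets this option.
            "FORCE_USE_FLEX_ATTENTION": True,
            **extra_kernel_options,
        },
    )


if __name__ == "__main__" and torch.cuda.is_available():
    # (batch, heads, seq_len, head_dim): illustrative shapes only.
    q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
    k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
    v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
    print(run_flex(q, k, v).shape)  # torch.Size([1, 8, 128, 64])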