[v1] Re-add fp32 support to v1 engine through FlexAttention (#19754)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -463,6 +463,13 @@ class FlexAttentionImpl(AttentionImpl):
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
+
+        # default M=64, N=64 may run out of shared memory on
+        # some GPUs with fp32, so we use smaller M and N.
+        extra_kernel_options = {
+            "BLOCK_M": 32,
+            "BLOCK_N": 32
+        } if query.dtype == torch.float32 else {}
         out = flex_attention_compiled(
             query,
             key_cache,
@@ -471,7 +478,10 @@ class FlexAttentionImpl(AttentionImpl):
             attn_metadata.block_mask,
             self.scale,
             enable_gqa=enable_gqa,
-            kernel_options={"FORCE_USE_FLEX_ATTENTION": True},
+            kernel_options={
+                "FORCE_USE_FLEX_ATTENTION": True,
+                **extra_kernel_options
+            },
         )

         # Flex doesn't have an out variant today, rely on epilogue fusion
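For context, a minimal standalone sketch of the same pattern outside vLLM: it shows how FlexAttention's kernel_options can shrink the Triton tile sizes for fp32, where each element is twice the size of fp16/bf16 and the default 64x64 tiles can exceed shared memory on some GPUs. The tensor names, shapes, and the plain torch.compile call below are illustrative assumptions, not taken from the vLLM source; it assumes a CUDA device and a PyTorch build that ships torch.nn.attention.flex_attention.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile once; kernel_options are forwarded to the generated Triton kernel
# at call time.
flex_attention_compiled = torch.compile(flex_attention)

# Illustrative shapes: (batch, heads, seq_len, head_dim) in fp32.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float32)

# Halving BLOCK_M/BLOCK_N roughly halves the per-tile shared-memory
# footprint, at the cost of launching more tiles.
extra_kernel_options = (
    {"BLOCK_M": 32, "BLOCK_N": 32} if q.dtype == torch.float32 else {}
)

out = flex_attention_compiled(
    q, k, v,
    kernel_options={"FORCE_USE_FLEX_ATTENTION": True, **extra_kernel_options},
)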