remove attn output view kernel (#26680)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Boyuan Feng <fby.1994@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -346,7 +346,7 @@ class Attention(nn.Module, AttentionLayerBase):
         if self.use_output:
             output_shape = output_shape if output_shape is not None else query.shape
-            output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
+            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
             hidden_size = output_shape[-1]
             # Reshape the query, key, and value tensors.
             # NOTE(woosuk): We do this outside the custom op to minimize the
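Note (illustrative, not part of the commit): the allocation switches from torch.zeros to torch.empty because this buffer is handed to a backend that is expected to overwrite every element, so the extra zero-fill kernel launch was pure overhead; the profiling path, where no kernel runs, is zero-filled explicitly in the backend hunks below. A minimal sketch of why the initial contents do not matter when the buffer is fully overwritten, with fake_backend_forward standing in for the attention kernel:

    import torch

    def fake_backend_forward(query: torch.Tensor, output: torch.Tensor) -> None:
        # Stand-in for a backend kernel that writes every element of `output`.
        output.copy_(query * 2.0)

    query = torch.randn(16, 32)

    out_zeros = torch.zeros_like(query)   # old allocation: launches a fill kernel
    fake_backend_forward(query, out_zeros)

    out_empty = torch.empty_like(query)   # new allocation: no fill kernel
    fake_backend_forward(query, out_empty)

    assert torch.equal(out_zeros, out_empty)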
@@ -705,7 +705,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                 self.calc_kv_scales(q, kv_c_normed, k_pe)

             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 self.impl.forward(
                     self,
                     q,
@@ -722,7 +722,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
                 )
         else:
             if self.attn_backend.accept_output_buffer:
-                output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                 torch.ops.vllm.unified_mla_attention_with_output(
                     q,
                     kv_c_normed,
@@ -530,7 +530,7 @@ class FlashAttentionImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         attn_type = self.attn_type

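Note (illustrative): each backend forward below receives the same one-line change. On a profiling run attn_metadata is None and no attention kernel executes, so the buffer, now allocated with torch.empty, would otherwise be returned with undefined contents; fill_(0) zeroes it in place on that path only. A hedged sketch of the pattern, using made-up names rather than the vLLM signatures:

    from typing import Optional

    import torch

    def backend_forward(query: torch.Tensor,
                        output: torch.Tensor,
                        attn_metadata: Optional[object]) -> torch.Tensor:
        if attn_metadata is None:
            # Profiling run: no kernel writes `output`, so give it defined
            # (zero) contents before returning it.
            return output.fill_(0)
        # Real run: the attention kernel overwrites every element of `output`;
        # copy_ stands in for that kernel here.
        output.copy_(query)
        return output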
@@ -857,7 +857,7 @@ class FlashInferImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
@@ -767,7 +767,7 @@ class FlexAttentionImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)
             # query = self.view_as_4d(query).permute(0, 2, 1, 3)
             # return torch.empty_like(query)

@@ -485,7 +485,7 @@ class AiterFlashAttentionImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
@@ -130,7 +130,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         assert attn_metadata.use_cascade is False

@@ -299,7 +299,7 @@ class RocmAttentionImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         assert attn_metadata.use_cascade is False

@@ -379,7 +379,7 @@ class TreeAttentionImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)
@@ -298,7 +298,7 @@ class TritonAttentionImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         assert attn_metadata.use_cascade is False

@@ -354,7 +354,7 @@ class XFormersAttentionImpl(AttentionImpl):

         if attn_metadata is None:
             # Profiling run.
-            return output
+            return output.fill_(0)

         # Cache the input KVs.
         key_cache, value_cache = kv_cache.unbind(0)
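Note (illustrative): Tensor.fill_ mutates in place and returns the same tensor object, so `return output.fill_(0)` keeps the shape/dtype/identity contract of the old `return output` while making the profiling-path contents deterministic. A quick check outside vLLM:

    import torch

    out = torch.empty(4, 8, dtype=torch.bfloat16)
    same = out.fill_(0)
    assert same is out                            # in-place, same object returned
    assert torch.count_nonzero(out).item() == 0   # contents are now all zeros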