[Attention][V0 Deprecation] Deprecate accept output buffer (#39125)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -216,12 +216,14 @@ def test_splitting_ops_dynamic():
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
splitting_ops=["vllm::unified_attention"],
|
||||
splitting_ops=["vllm::unified_attention_with_output"],
|
||||
)
|
||||
)
|
||||
# with inductor partition we use splitting_ops directly for
|
||||
# partition rules
|
||||
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
|
||||
assert config.compilation_config.splitting_ops == [
|
||||
"vllm::unified_attention_with_output"
|
||||
]
|
||||
|
||||
# When attn_fusion pass enabled.
|
||||
config = VllmConfig(
|
||||
@@ -281,7 +283,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
splitting_ops=[
|
||||
"vllm::unified_attention",
|
||||
"vllm::unified_attention_with_output",
|
||||
"vllm::moe_forward",
|
||||
"vllm::moe_forward_shared",
|
||||
],
|
||||
@@ -289,7 +291,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
|
||||
)
|
||||
splitting_ops = config.compilation_config.splitting_ops
|
||||
assert splitting_ops == [
|
||||
"vllm::unified_attention",
|
||||
"vllm::unified_attention_with_output",
|
||||
"vllm::moe_forward",
|
||||
"vllm::moe_forward_shared",
|
||||
]
|
||||
|
||||
@@ -282,7 +282,7 @@ class PassConfig:
|
||||
"""
|
||||
enabled_fusions = [
|
||||
f.name[len("fuse_") :]
|
||||
for f in fields(self)
|
||||
for f in fields(self) # type: ignore[arg-type]
|
||||
if getattr(self, f.name) and f.name.startswith("fuse_")
|
||||
]
|
||||
|
||||
@@ -711,9 +711,7 @@ class CompilationConfig:
|
||||
# Attention ops; used for piecewise cudagraphs
|
||||
# Use PyTorch operator format: "namespace::name"
|
||||
_attention_ops: ClassVar[list[str]] = [
|
||||
"vllm::unified_attention",
|
||||
"vllm::unified_attention_with_output",
|
||||
"vllm::unified_mla_attention",
|
||||
"vllm::unified_mla_attention_with_output",
|
||||
"vllm::mamba_mixer2",
|
||||
"vllm::mamba_mixer",
|
||||
|
||||
@@ -354,7 +354,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
# and let torch.compile handle them.
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
@@ -429,75 +428,62 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
if self.impl.supports_quant_query_input:
|
||||
query, _ = self.query_quant(query, self._q_scale)
|
||||
|
||||
if self.use_output:
|
||||
if output_shape is None:
|
||||
# Handle both 2D [num_tokens, hidden] and
|
||||
# 3D [num_tokens, heads, head_dim] query
|
||||
num_tokens = query.shape[0]
|
||||
output_shape = torch.Size(
|
||||
(num_tokens, self.num_heads * self.head_size_v)
|
||||
if output_shape is None:
|
||||
# Handle both 2D [num_tokens, hidden] and
|
||||
# 3D [num_tokens, heads, head_dim] query
|
||||
num_tokens = query.shape[0]
|
||||
output_shape = torch.Size((num_tokens, self.num_heads * self.head_size_v))
|
||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size_v)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
kv_cache_dummy_dep = None
|
||||
if self.use_direct_call:
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
|
||||
hidden_size = output_shape[-1]
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the
|
||||
# CPU overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
output = output.view(-1, self.num_heads, self.head_size_v)
|
||||
if key is not None:
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
if value is not None:
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
kv_cache_dummy_dep = None
|
||||
if self.use_direct_call:
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
else:
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
assert self.attn_backend.forward_includes_kv_cache_update, (
|
||||
"Split KV cache update not supported when output tensor not provided."
|
||||
unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
if self.use_direct_call:
|
||||
return unified_attention(query, key, value, self.layer_name)
|
||||
else:
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name
|
||||
else:
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if (
|
||||
not self.attn_backend.forward_includes_kv_cache_update
|
||||
and self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
|
||||
key, value, self.layer_name
|
||||
)
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return output.view(-1, hidden_size)
|
||||
|
||||
def calc_kv_scales(self, query, key, value):
|
||||
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
|
||||
@@ -633,35 +619,6 @@ def get_attention_context(
|
||||
return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
def unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
|
||||
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def unified_attention_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query).contiguous()
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_attention",
|
||||
op_func=unified_attention,
|
||||
fake_impl=unified_attention_fake,
|
||||
)
|
||||
|
||||
|
||||
def unified_kv_cache_update(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
|
||||
@@ -133,7 +133,7 @@ def create_cross_attention_backend(
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -494,21 +494,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.kv_cache_dtype,
|
||||
self._k_scale,
|
||||
)
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
self.forward_impl(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return self.forward_impl(
|
||||
q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
|
||||
)
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
self.forward_impl(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
self_kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
|
||||
kv_c_normed,
|
||||
@@ -517,25 +512,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.kv_cache_dtype,
|
||||
self._k_scale,
|
||||
)
|
||||
if self.attn_backend.accept_output_buffer:
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
torch.ops.vllm.unified_mla_attention_with_output(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
return torch.ops.vllm.unified_mla_attention(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
|
||||
torch.ops.vllm.unified_mla_attention_with_output(
|
||||
q,
|
||||
kv_c_normed,
|
||||
k_pe,
|
||||
output,
|
||||
self.layer_name,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return output
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
@@ -544,12 +530,10 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: "MLACommonMetadata",
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
use_quant = output_scale is not None or output_block_scale is not None
|
||||
if use_quant:
|
||||
# The fusion pass has allocated output with quantized dtype
|
||||
@@ -913,43 +897,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
out.copy_(out_new) # Copy result
|
||||
|
||||
|
||||
@maybe_transfer_kv_layer
|
||||
def unified_mla_attention(
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
layer_name: str,
|
||||
kv_cache_dummy_dep: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
|
||||
# that ensures torch.compile preserves ordering between KV cache update and
|
||||
# attention forward.
|
||||
del kv_cache_dummy_dep
|
||||
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
|
||||
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def unified_mla_attention_fake(
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
layer_name: str,
|
||||
kv_cache_dummy_dep: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q).contiguous()
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_mla_attention",
|
||||
op_func=unified_mla_attention,
|
||||
mutates_args=[],
|
||||
fake_impl=unified_mla_attention_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
def unified_mla_kv_cache_update(
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
@@ -1151,8 +1098,6 @@ CUDNN_WORKSPACE_SIZE = 12800
|
||||
|
||||
|
||||
class MLACommonBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA"
|
||||
|
||||
@@ -94,7 +94,6 @@ def basic_cache(
|
||||
class CacheOnlyAttentionBackend(AttentionBackend):
|
||||
"""Attention backend that only caches KV without computing attention."""
|
||||
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
|
||||
@@ -184,7 +184,7 @@ def create_whisper_attention_backend_with_block_pooling(
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -53,10 +53,6 @@ class MultipleOf:
|
||||
class AttentionBackend(ABC):
|
||||
"""Abstract class for attention backends."""
|
||||
|
||||
# For some attention backends, we allocate an output tensor before
|
||||
# calling the custom op. When piecewise cudagraph is enabled, this
|
||||
# makes sure the output tensor is allocated inside the cudagraph.
|
||||
accept_output_buffer: bool = False
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [
|
||||
"auto",
|
||||
@@ -779,7 +775,7 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -30,7 +30,6 @@ _CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S3
|
||||
|
||||
|
||||
class CPUAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
@@ -267,7 +266,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: CPUAttentionMetadata | None,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -283,7 +282,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
|
||||
@@ -62,7 +62,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
@@ -664,7 +663,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -683,7 +682,6 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
|
||||
@@ -128,7 +128,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -147,7 +147,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
assert self.vllm_flash_attn_version is not None, (
|
||||
"FlashAttention version not detected."
|
||||
)
|
||||
|
||||
@@ -315,7 +315,6 @@ class BatchDCPPrefillWrapper:
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
@@ -1286,7 +1285,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashInferMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -1303,8 +1302,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
@@ -73,7 +73,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
|
||||
|
||||
|
||||
class FlexAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
@@ -992,7 +991,7 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlexAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -1008,7 +1007,6 @@ class FlexAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for FlexAttentionImpl"
|
||||
|
||||
@@ -59,7 +59,6 @@ class FlashInferMLASparseBackend(AttentionBackend):
|
||||
for models like DeepSeek-V3.2 that use index-based sparse attention.
|
||||
"""
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
|
||||
@@ -78,7 +78,6 @@ structured as:
|
||||
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
|
||||
@@ -78,7 +78,6 @@ def fetch_id_to_ragged_triton(
|
||||
|
||||
|
||||
class ROCMAiterMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
|
||||
@@ -35,7 +35,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XPUMLASparseBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
|
||||
@@ -744,7 +744,6 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
|
||||
|
||||
class AiterFlashAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
@@ -1037,7 +1036,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AiterFlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -1056,8 +1055,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported "
|
||||
|
||||
@@ -24,8 +24,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
|
||||
return [MultipleOf(16)]
|
||||
@@ -143,7 +141,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -159,8 +157,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
|
||||
@@ -159,7 +159,6 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
|
||||
|
||||
|
||||
class RocmAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
@@ -352,7 +351,7 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -368,8 +367,6 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
|
||||
@@ -30,7 +30,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TreeAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
|
||||
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
|
||||
"auto",
|
||||
@@ -368,7 +367,7 @@ class TreeAttentionImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TreeAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -384,8 +383,6 @@ class TreeAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for TreeAttentionImpl"
|
||||
|
||||
@@ -262,7 +262,6 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
|
||||
accept_output_buffer: bool = True
|
||||
supported_dtypes: ClassVar[list[torch.dtype]] = [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
@@ -504,7 +503,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TritonAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output: torch.Tensor,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -520,8 +519,6 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused block_scale output quantization is not yet supported"
|
||||
|
||||
Reference in New Issue
Block a user