[Attention][V0 Deprecation] Deprecate accept output buffer (#39125)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-04-07 17:14:58 -04:00
committed by GitHub
parent 08bfedc152
commit 70406eb1dc
22 changed files with 94 additions and 227 deletions

View File

@@ -216,12 +216,14 @@ def test_splitting_ops_dynamic():
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=["vllm::unified_attention"],
splitting_ops=["vllm::unified_attention_with_output"],
)
)
# with inductor partition we use splitting_ops directly for
# partition rules
assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]
assert config.compilation_config.splitting_ops == [
"vllm::unified_attention_with_output"
]
# When attn_fusion pass enabled.
config = VllmConfig(
@@ -281,7 +283,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
splitting_ops=[
"vllm::unified_attention",
"vllm::unified_attention_with_output",
"vllm::moe_forward",
"vllm::moe_forward_shared",
],
@@ -289,7 +291,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
)
splitting_ops = config.compilation_config.splitting_ops
assert splitting_ops == [
"vllm::unified_attention",
"vllm::unified_attention_with_output",
"vllm::moe_forward",
"vllm::moe_forward_shared",
]

View File

@@ -282,7 +282,7 @@ class PassConfig:
"""
enabled_fusions = [
f.name[len("fuse_") :]
for f in fields(self)
for f in fields(self) # type: ignore[arg-type]
if getattr(self, f.name) and f.name.startswith("fuse_")
]
@@ -711,9 +711,7 @@ class CompilationConfig:
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [
"vllm::unified_attention",
"vllm::unified_attention_with_output",
"vllm::unified_mla_attention",
"vllm::unified_mla_attention_with_output",
"vllm::mamba_mixer2",
"vllm::mamba_mixer",

View File

@@ -354,7 +354,6 @@ class Attention(nn.Module, AttentionLayerBase):
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
@@ -429,75 +428,62 @@ class Attention(nn.Module, AttentionLayerBase):
if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
if output_shape is None:
# Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query
num_tokens = query.shape[0]
output_shape = torch.Size(
(num_tokens, self.num_heads * self.head_size_v)
if output_shape is None:
# Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query
num_tokens = query.shape[0]
output_shape = torch.Size((num_tokens, self.num_heads * self.head_size_v))
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size_v)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
kv_cache_dummy_dep = None
if self.use_direct_call:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name
)
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size_v)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
kv_cache_dummy_dep = None
if self.use_direct_call:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name
)
unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
else:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
)
torch.ops.vllm.unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output.view(-1, hidden_size)
else:
assert self.attn_backend.forward_includes_kv_cache_update, (
"Split KV cache update not supported when output tensor not provided."
unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
else:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
)
torch.ops.vllm.unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output.view(-1, hidden_size)
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
@@ -633,35 +619,6 @@ def get_attention_context(
return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
@maybe_transfer_kv_layer
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
return output
def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
)
def unified_kv_cache_update(
key: torch.Tensor,
value: torch.Tensor,

View File

@@ -133,7 +133,7 @@ def create_cross_attention_backend(
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:

View File

@@ -494,21 +494,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.kv_cache_dtype,
self._k_scale,
)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.forward_impl(
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
return self.forward_impl(
q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
)
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
self.forward_impl(
q,
kv_c_normed,
k_pe,
self_kv_cache,
attn_metadata,
output=output,
)
return output
else:
kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
kv_c_normed,
@@ -517,25 +512,16 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.kv_cache_dtype,
self._k_scale,
)
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output
else:
return torch.ops.vllm.unified_mla_attention(
q,
kv_c_normed,
k_pe,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
q,
kv_c_normed,
k_pe,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output
def forward_impl(
self,
@@ -544,12 +530,10 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: "MLACommonMetadata",
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
use_quant = output_scale is not None or output_block_scale is not None
if use_quant:
# The fusion pass has allocated output with quantized dtype
@@ -913,43 +897,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
out.copy_(out_new) # Copy result
@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
def unified_mla_attention_fake(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
direct_register_custom_op(
op_name="unified_mla_attention",
op_func=unified_mla_attention,
mutates_args=[],
fake_impl=unified_mla_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
def unified_mla_kv_cache_update(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
@@ -1151,8 +1098,6 @@ CUDNN_WORKSPACE_SIZE = 12800
class MLACommonBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "TRITON_MLA"

View File

@@ -94,7 +94,6 @@ def basic_cache(
class CacheOnlyAttentionBackend(AttentionBackend):
"""Attention backend that only caches KV without computing attention."""
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,

View File

@@ -184,7 +184,7 @@ def create_whisper_attention_backend_with_block_pooling(
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:

View File

@@ -53,10 +53,6 @@ class MultipleOf:
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [
"auto",
@@ -779,7 +775,7 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:

View File

@@ -30,7 +30,6 @@ _CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S3
class CPUAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
@@ -267,7 +266,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: CPUAttentionMetadata | None,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -283,7 +282,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"

View File

@@ -62,7 +62,6 @@ logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
@@ -664,7 +663,7 @@ class FlashAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -683,7 +682,6 @@ class FlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)

View File

@@ -128,7 +128,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -147,7 +147,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)

View File

@@ -315,7 +315,6 @@ class BatchDCPPrefillWrapper:
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
@@ -1286,7 +1285,7 @@ class FlashInferImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -1303,8 +1302,6 @@ class FlashInferImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output.fill_(0)

View File

@@ -73,7 +73,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
@@ -992,7 +991,7 @@ class FlexAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlexAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -1008,7 +1007,6 @@ class FlexAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for FlexAttentionImpl"

View File

@@ -59,7 +59,6 @@ class FlashInferMLASparseBackend(AttentionBackend):
for models like DeepSeek-V3.2 that use index-based sparse attention.
"""
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",

View File

@@ -78,7 +78,6 @@ structured as:
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",

View File

@@ -78,7 +78,6 @@ def fetch_id_to_ragged_triton(
class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",

View File

@@ -35,7 +35,6 @@ logger = init_logger(__name__)
class XPUMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",

View File

@@ -744,7 +744,6 @@ class AiterFlashAttentionMetadataBuilder(
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
@@ -1037,7 +1036,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AiterFlashAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -1056,8 +1055,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported "

View File

@@ -24,8 +24,6 @@ logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]
@@ -143,7 +141,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -159,8 +157,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"

View File

@@ -159,7 +159,6 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
class RocmAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
@@ -352,7 +351,7 @@ class RocmAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -368,8 +367,6 @@ class RocmAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"

View File

@@ -30,7 +30,6 @@ logger = init_logger(__name__)
class TreeAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
@@ -368,7 +367,7 @@ class TreeAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TreeAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -384,8 +383,6 @@ class TreeAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported for TreeAttentionImpl"

View File

@@ -262,7 +262,6 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet
class TritonAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [
torch.float16,
torch.bfloat16,
@@ -504,7 +503,7 @@ class TritonAttentionImpl(AttentionImpl):
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TritonAttentionMetadata,
output: torch.Tensor | None = None,
output: torch.Tensor,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
@@ -520,8 +519,6 @@ class TritonAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"