From 70406eb1dc191fa693258452c561855dbdf211a5 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 7 Apr 2026 17:14:58 -0400 Subject: [PATCH] [Attention][V0 Deprecation] Deprecate accept output buffer (#39125) Signed-off-by: Lucas Wilkinson --- tests/compile/test_config.py | 10 +- vllm/config/compilation.py | 4 +- .../layers/attention/attention.py | 149 +++++++----------- .../layers/attention/cross_attention.py | 2 +- .../layers/attention/mla_attention.py | 97 +++--------- .../models/extract_hidden_states.py | 1 - vllm/model_executor/models/whisper_causal.py | 2 +- vllm/v1/attention/backend.py | 6 +- vllm/v1/attention/backends/cpu_attn.py | 4 +- vllm/v1/attention/backends/flash_attn.py | 4 +- .../attention/backends/flash_attn_diffkv.py | 3 +- vllm/v1/attention/backends/flashinfer.py | 5 +- vllm/v1/attention/backends/flex_attention.py | 4 +- .../backends/mla/flashinfer_mla_sparse.py | 1 - .../attention/backends/mla/flashmla_sparse.py | 1 - .../backends/mla/rocm_aiter_mla_sparse.py | 1 - .../attention/backends/mla/xpu_mla_sparse.py | 1 - vllm/v1/attention/backends/rocm_aiter_fa.py | 5 +- .../backends/rocm_aiter_unified_attn.py | 6 +- vllm/v1/attention/backends/rocm_attn.py | 5 +- vllm/v1/attention/backends/tree_attn.py | 5 +- vllm/v1/attention/backends/triton_attn.py | 5 +- 22 files changed, 94 insertions(+), 227 deletions(-) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index d5808955b..c0f0dcca8 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -216,12 +216,14 @@ def test_splitting_ops_dynamic(): compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, - splitting_ops=["vllm::unified_attention"], + splitting_ops=["vllm::unified_attention_with_output"], ) ) # with inductor partition we use splitting_ops directly for # partition rules - assert config.compilation_config.splitting_ops == ["vllm::unified_attention"] + assert config.compilation_config.splitting_ops == [ + "vllm::unified_attention_with_output" + ] # When attn_fusion pass enabled. config = VllmConfig( @@ -281,7 +283,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition(): mode=CompilationMode.VLLM_COMPILE, use_inductor_graph_partition=True, splitting_ops=[ - "vllm::unified_attention", + "vllm::unified_attention_with_output", "vllm::moe_forward", "vllm::moe_forward_shared", ], @@ -289,7 +291,7 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition(): ) splitting_ops = config.compilation_config.splitting_ops assert splitting_ops == [ - "vllm::unified_attention", + "vllm::unified_attention_with_output", "vllm::moe_forward", "vllm::moe_forward_shared", ] diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 716c208a9..ef2a4bf5a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -282,7 +282,7 @@ class PassConfig: """ enabled_fusions = [ f.name[len("fuse_") :] - for f in fields(self) + for f in fields(self) # type: ignore[arg-type] if getattr(self, f.name) and f.name.startswith("fuse_") ] @@ -711,9 +711,7 @@ class CompilationConfig: # Attention ops; used for piecewise cudagraphs # Use PyTorch operator format: "namespace::name" _attention_ops: ClassVar[list[str]] = [ - "vllm::unified_attention", "vllm::unified_attention_with_output", - "vllm::unified_mla_attention", "vllm::unified_mla_attention_with_output", "vllm::mamba_mixer2", "vllm::mamba_mixer", diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index a49415a3d..27cc21191 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -354,7 +354,6 @@ class Attention(nn.Module, AttentionLayerBase): # and let torch.compile handle them. self.use_direct_call = not current_platform.opaque_attention_op() - self.use_output = self.attn_backend.accept_output_buffer compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") @@ -429,75 +428,62 @@ class Attention(nn.Module, AttentionLayerBase): if self.impl.supports_quant_query_input: query, _ = self.query_quant(query, self._q_scale) - if self.use_output: - if output_shape is None: - # Handle both 2D [num_tokens, hidden] and - # 3D [num_tokens, heads, head_dim] query - num_tokens = query.shape[0] - output_shape = torch.Size( - (num_tokens, self.num_heads * self.head_size_v) + if output_shape is None: + # Handle both 2D [num_tokens, hidden] and + # 3D [num_tokens, heads, head_dim] query + num_tokens = query.shape[0] + output_shape = torch.Size((num_tokens, self.num_heads * self.head_size_v)) + output = torch.empty(output_shape, dtype=output_dtype, device=query.device) + hidden_size = output_shape[-1] + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size_v) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size_v) + kv_cache_dummy_dep = None + if self.use_direct_call: + # Skip this if sharing KV cache with an earlier attention layer. + if ( + not self.attn_backend.forward_includes_kv_cache_update + and self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + kv_cache_dummy_dep = unified_kv_cache_update( + key, value, self.layer_name ) - output = torch.empty(output_shape, dtype=output_dtype, device=query.device) - hidden_size = output_shape[-1] - # Reshape the query, key, and value tensors. - # NOTE(woosuk): We do this outside the custom op to minimize the - # CPU overheads from the non-CUDA-graph regions. - query = query.view(-1, self.num_heads, self.head_size) - output = output.view(-1, self.num_heads, self.head_size_v) - if key is not None: - key = key.view(-1, self.num_kv_heads, self.head_size) - if value is not None: - value = value.view(-1, self.num_kv_heads, self.head_size_v) - kv_cache_dummy_dep = None - if self.use_direct_call: - # Skip this if sharing KV cache with an earlier attention layer. - if ( - not self.attn_backend.forward_includes_kv_cache_update - and self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - kv_cache_dummy_dep = unified_kv_cache_update( - key, value, self.layer_name - ) - unified_attention_with_output( - query, - key, - value, - output, - self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, - ) - else: - # Skip this if sharing KV cache with an earlier attention layer. - if ( - not self.attn_backend.forward_includes_kv_cache_update - and self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( - key, value, self.layer_name - ) - torch.ops.vllm.unified_attention_with_output( - query, - key, - value, - output, - self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, - ) - return output.view(-1, hidden_size) - else: - assert self.attn_backend.forward_includes_kv_cache_update, ( - "Split KV cache update not supported when output tensor not provided." + unified_attention_with_output( + query, + key, + value, + output, + self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) - if self.use_direct_call: - return unified_attention(query, key, value, self.layer_name) - else: - return torch.ops.vllm.unified_attention( - query, key, value, self.layer_name + else: + # Skip this if sharing KV cache with an earlier attention layer. + if ( + not self.attn_backend.forward_includes_kv_cache_update + and self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( + key, value, self.layer_name ) + torch.ops.vllm.unified_attention_with_output( + query, + key, + value, + output, + self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, + ) + return output.view(-1, hidden_size) def calc_kv_scales(self, query, key, value): self._q_scale.copy_(torch.abs(query).max() / self.q_range) @@ -633,35 +619,6 @@ def get_attention_context( return attn_metadata, attn_layer, kv_cache, layer_slot_mapping -@maybe_transfer_kv_layer -def unified_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - layer_name: str, -) -> torch.Tensor: - attn_metadata, self, kv_cache, _ = get_attention_context(layer_name) - output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata) - - return output - - -def unified_attention_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - layer_name: str, -) -> torch.Tensor: - return torch.empty_like(query).contiguous() - - -direct_register_custom_op( - op_name="unified_attention", - op_func=unified_attention, - fake_impl=unified_attention_fake, -) - - def unified_kv_cache_update( key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/model_executor/layers/attention/cross_attention.py b/vllm/model_executor/layers/attention/cross_attention.py index 31ac7fa1b..61699832a 100644 --- a/vllm/model_executor/layers/attention/cross_attention.py +++ b/vllm/model_executor/layers/attention/cross_attention.py @@ -133,7 +133,7 @@ def create_cross_attention_backend( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 699238b48..1d046b16e 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -494,21 +494,16 @@ class MLAAttention(nn.Module, AttentionLayerBase): self.kv_cache_dtype, self._k_scale, ) - if self.attn_backend.accept_output_buffer: - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - self.forward_impl( - q, - kv_c_normed, - k_pe, - self_kv_cache, - attn_metadata, - output=output, - ) - return output - else: - return self.forward_impl( - q, kv_c_normed, k_pe, self_kv_cache, attn_metadata - ) + output = torch.empty(output_shape, dtype=q.dtype, device=q.device) + self.forward_impl( + q, + kv_c_normed, + k_pe, + self_kv_cache, + attn_metadata, + output=output, + ) + return output else: kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( kv_c_normed, @@ -517,25 +512,16 @@ class MLAAttention(nn.Module, AttentionLayerBase): self.kv_cache_dtype, self._k_scale, ) - if self.attn_backend.accept_output_buffer: - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - torch.ops.vllm.unified_mla_attention_with_output( - q, - kv_c_normed, - k_pe, - output, - self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, - ) - return output - else: - return torch.ops.vllm.unified_mla_attention( - q, - kv_c_normed, - k_pe, - self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, - ) + output = torch.empty(output_shape, dtype=q.dtype, device=q.device) + torch.ops.vllm.unified_mla_attention_with_output( + q, + kv_c_normed, + k_pe, + output, + self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, + ) + return output def forward_impl( self, @@ -544,12 +530,10 @@ class MLAAttention(nn.Module, AttentionLayerBase): k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: "MLACommonMetadata", - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - use_quant = output_scale is not None or output_block_scale is not None if use_quant: # The fusion pass has allocated output with quantized dtype @@ -913,43 +897,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): out.copy_(out_new) # Copy result -@maybe_transfer_kv_layer -def unified_mla_attention( - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - layer_name: str, - kv_cache_dummy_dep: torch.Tensor | None = None, -) -> torch.Tensor: - # kv_cache_dummy_dep is not used but accepting it creates a data dependency - # that ensures torch.compile preserves ordering between KV cache update and - # attention forward. - del kv_cache_dummy_dep - attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) - output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) - - return output - - -def unified_mla_attention_fake( - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - layer_name: str, - kv_cache_dummy_dep: torch.Tensor | None = None, -) -> torch.Tensor: - return torch.empty_like(q).contiguous() - - -direct_register_custom_op( - op_name="unified_mla_attention", - op_func=unified_mla_attention, - mutates_args=[], - fake_impl=unified_mla_attention_fake, - dispatch_key=current_platform.dispatch_key, -) - - def unified_mla_kv_cache_update( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, @@ -1151,8 +1098,6 @@ CUDNN_WORKSPACE_SIZE = 12800 class MLACommonBackend(AttentionBackend): - accept_output_buffer: bool = True - @staticmethod def get_name() -> str: return "TRITON_MLA" diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py index 3f1e7e693..e2d39ead6 100644 --- a/vllm/model_executor/models/extract_hidden_states.py +++ b/vllm/model_executor/models/extract_hidden_states.py @@ -94,7 +94,6 @@ def basic_cache( class CacheOnlyAttentionBackend(AttentionBackend): """Attention backend that only caches KV without computing attention.""" - accept_output_buffer: bool = False supported_dtypes: ClassVar[list[torch.dtype]] = [ torch.float16, torch.bfloat16, diff --git a/vllm/model_executor/models/whisper_causal.py b/vllm/model_executor/models/whisper_causal.py index 8e4322ea3..dfbf69418 100644 --- a/vllm/model_executor/models/whisper_causal.py +++ b/vllm/model_executor/models/whisper_causal.py @@ -184,7 +184,7 @@ def create_whisper_attention_backend_with_block_pooling( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index bb05b31bb..28d077fcb 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -53,10 +53,6 @@ class MultipleOf: class AttentionBackend(ABC): """Abstract class for attention backends.""" - # For some attention backends, we allocate an output tensor before - # calling the custom op. When piecewise cudagraph is enabled, this - # makes sure the output tensor is allocated inside the cudagraph. - accept_output_buffer: bool = False supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = [ "auto", @@ -779,7 +775,7 @@ class AttentionImpl(AttentionImplBase[T], Generic[T]): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 5216301ef..76c78fed4 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -30,7 +30,6 @@ _CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86, CpuArchEnum.ARM, CpuArchEnum.S3 class CPUAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [ torch.float16, torch.bfloat16, @@ -267,7 +266,7 @@ class CPUAttentionBackendImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: CPUAttentionMetadata | None, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -283,7 +282,6 @@ class CPUAttentionBackendImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d72c2aeb6..5e81cae42 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -62,7 +62,6 @@ logger = init_logger(__name__) class FlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", @@ -664,7 +663,7 @@ class FlashAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -683,7 +682,6 @@ class FlashAttentionImpl(AttentionImpl): {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ - assert output is not None, "Output tensor must be provided." assert self.vllm_flash_attn_version is not None, ( "FlashAttention version not detected." ) diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py index 2a8ab86af..cf98d3d09 100644 --- a/vllm/v1/attention/backends/flash_attn_diffkv.py +++ b/vllm/v1/attention/backends/flash_attn_diffkv.py @@ -128,7 +128,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -147,7 +147,6 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ - assert output is not None, "Output tensor must be provided." assert self.vllm_flash_attn_version is not None, ( "FlashAttention version not detected." ) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index e2f9f2b8c..3543bdabb 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -315,7 +315,6 @@ class BatchDCPPrefillWrapper: class FlashInferBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", @@ -1286,7 +1285,7 @@ class FlashInferImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -1303,8 +1302,6 @@ class FlashInferImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - assert output is not None, "Output tensor must be provided." - if attn_metadata is None: # Profiling run. return output.fill_(0) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 5e202e00f..b481a1f16 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -73,7 +73,6 @@ def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): class FlexAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [ torch.float16, torch.bfloat16, @@ -992,7 +991,7 @@ class FlexAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlexAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -1008,7 +1007,6 @@ class FlexAttentionImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - assert output is not None, "Output tensor must be provided." if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported for FlexAttentionImpl" diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py index 1eb12f72e..842153f40 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py @@ -59,7 +59,6 @@ class FlashInferMLASparseBackend(AttentionBackend): for models like DeepSeek-V3.2 that use index-based sparse attention. """ - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 816ad88a8..1d981717c 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -78,7 +78,6 @@ structured as: class FlashMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py index f14271d1b..f24aa6055 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -78,7 +78,6 @@ def fetch_id_to_ragged_triton( class ROCMAiterMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", diff --git a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py index 59ec42e93..2fa91d018 100644 --- a/vllm/v1/attention/backends/mla/xpu_mla_sparse.py +++ b/vllm/v1/attention/backends/mla/xpu_mla_sparse.py @@ -35,7 +35,6 @@ logger = init_logger(__name__) class XPUMLASparseBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index a897b33fc..abbbcafb5 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -744,7 +744,6 @@ class AiterFlashAttentionMetadataBuilder( class AiterFlashAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", @@ -1037,7 +1036,7 @@ class AiterFlashAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AiterFlashAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -1056,8 +1055,6 @@ class AiterFlashAttentionImpl(AttentionImpl): {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ - assert output is not None, "Output tensor must be provided." - if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported " diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index c91f8a225..eb0fe046e 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -24,8 +24,6 @@ logger = init_logger(__name__) class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend): - accept_output_buffer: bool = True - @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: return [MultipleOf(16)] @@ -143,7 +141,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -159,8 +157,6 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - assert output is not None, "Output tensor must be provided." - if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported" diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index a8448c489..3a9062332 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -159,7 +159,6 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat class RocmAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [ torch.float16, torch.bfloat16, @@ -352,7 +351,7 @@ class RocmAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -368,8 +367,6 @@ class RocmAttentionImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - assert output is not None, "Output tensor must be provided." - if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported" diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 587f71628..ceee8d549 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -30,7 +30,6 @@ logger = init_logger(__name__) class TreeAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ "auto", @@ -368,7 +367,7 @@ class TreeAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TreeAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -384,8 +383,6 @@ class TreeAttentionImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - assert output is not None, "Output tensor must be provided." - if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported for TreeAttentionImpl" diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 5b1eec385..bd8ec29bc 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -262,7 +262,6 @@ class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMet class TritonAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True supported_dtypes: ClassVar[list[torch.dtype]] = [ torch.float16, torch.bfloat16, @@ -504,7 +503,7 @@ class TritonAttentionImpl(AttentionImpl): value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TritonAttentionMetadata, - output: torch.Tensor | None = None, + output: torch.Tensor, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: @@ -520,8 +519,6 @@ class TritonAttentionImpl(AttentionImpl): Returns: shape = [num_tokens, num_heads * head_size] """ - assert output is not None, "Output tensor must be provided." - if output_block_scale is not None: raise NotImplementedError( "fused block_scale output quantization is not yet supported"