[CI] Revert PRs 34818 and 33600 (#34979)
@@ -30,8 +30,9 @@ from vllm.v1.kv_cache_interface import (
 def create_chunked_local_attention_backend(
     underlying_attn_backend: AttentionBackend,
     attention_chunk_size: int,
+    block_size: int,
 ) -> type[AttentionBackend]:
-    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
+    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
 
     underlying_builder = underlying_attn_backend.get_builder_cls()
     assert issubclass(underlying_builder, AttentionMetadataBuilder)
@@ -54,9 +55,7 @@ def create_chunked_local_attention_backend(
             fast_build: bool = False,
         ):
             cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
-                attention_chunk_size,
-                common_attn_metadata,
-                self.kv_cache_spec.block_size,
+                attention_chunk_size, common_attn_metadata, block_size
             )
             metadata = super().build(common_prefix_len, cm, fast_build)
             metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
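With block_size restored as a factory argument, the generated builder reads it from the enclosing closure rather than from self.kv_cache_spec.block_size. Below is a minimal, self-contained sketch of that closure-capture pattern, not vLLM code; BaseBuilder and the dict-based metadata are hypothetical stand-ins.

# Sketch only: shows how a factory-created builder class can capture
# parameters (attention_chunk_size, block_size) from the factory's closure.
# BaseBuilder and the dict "metadata" are hypothetical stand-ins, not vLLM APIs.
class BaseBuilder:
    def build(self, metadata: dict) -> dict:
        return dict(metadata)


def create_local_builder(attention_chunk_size: int, block_size: int) -> type[BaseBuilder]:
    class LocalBuilder(BaseBuilder):
        def build(self, metadata: dict) -> dict:
            # block_size comes from the factory closure, not from instance
            # state, so it is fixed when the class is created.
            out = super().build(metadata)
            out["chunk_size"] = attention_chunk_size
            out["block_size"] = block_size
            return out

    return LocalBuilder


if __name__ == "__main__":
    Builder = create_local_builder(attention_chunk_size=128, block_size=16)
    print(Builder().build({"num_tokens": 1024}))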
@@ -98,13 +97,13 @@ class ChunkedLocalAttention(Attention):
             block_size = cache_config.block_size
         else:
             kv_cache_dtype = "auto"
-            block_size = None
+            block_size = 16
 
         underlying_attn_backend = get_attn_backend(
             head_size, dtype, kv_cache_dtype, block_size
         )
         attn_backend = create_chunked_local_attention_backend(
-            underlying_attn_backend, attention_chunk_size
+            underlying_attn_backend, attention_chunk_size, block_size
         )
 
         super().__init__(
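Taken together, the three hunks above restore explicit block_size plumbing: ChunkedLocalAttention.__init__ resolves a block size (from cache_config when present, otherwise 16) and passes it both to get_attn_backend and to create_chunked_local_attention_backend, which also bakes it into the ChunkedLocalAttention_..._ prefix so backends built for different block sizes get distinct names. A minimal sketch of that resolution and naming follows; SimpleCacheConfig is hypothetical, and only the prefix format comes from the diff.

# Sketch only: block-size resolution and backend-name prefix, mirroring the
# diff above. SimpleCacheConfig is a hypothetical stand-in for vLLM's
# CacheConfig; only the prefix format is taken from the diff.
from dataclasses import dataclass


@dataclass
class SimpleCacheConfig:
    block_size: int = 32


def resolve_block_size(cache_config: SimpleCacheConfig | None) -> int:
    # Prefer the configured block size; fall back to 16, as in the revert.
    return cache_config.block_size if cache_config is not None else 16


def backend_prefix(attention_chunk_size: int, block_size: int) -> str:
    # Encoding both values keeps backends for different block sizes distinct.
    return f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"


if __name__ == "__main__":
    print(backend_prefix(128, resolve_block_size(None)))                 # ChunkedLocalAttention_128_16_
    print(backend_prefix(128, resolve_block_size(SimpleCacheConfig())))  # ChunkedLocalAttention_128_32_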
@@ -407,24 +407,17 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         )
 
         # Attributes for forward_impl method
-        self._vllm_config = get_current_vllm_config()
-        self._chunked_prefill_workspace_size: int | None = None
+        self.chunked_prefill_workspace_size = (
+            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+                get_current_vllm_config()
+            )
+        )
         self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
             static=True,
             group_shape=GroupShape.PER_TENSOR,
             compile_native=True,
         )
 
-    @property
-    def chunked_prefill_workspace_size(self) -> int:
-        if self._chunked_prefill_workspace_size is None:
-            self._chunked_prefill_workspace_size = (
-                MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
-                    self._vllm_config
-                )
-            )
-        return self._chunked_prefill_workspace_size
-
     def forward(
         self,
         q: torch.Tensor,
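The last hunk replaces the lazily computed chunked_prefill_workspace_size property with a value computed once in __init__. A minimal sketch contrasting the two patterns; expensive_workspace_size is a hypothetical stand-in for MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size.

# Sketch only: the removed lazy-property pattern vs. the restored eager
# computation. expensive_workspace_size is a hypothetical stand-in for
# MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size.
def expensive_workspace_size(config: dict) -> int:
    return 2 * config.get("max_num_batched_tokens", 8192)


class LazyWorkspace:
    """Pattern removed by the revert: compute on first access, then cache."""

    def __init__(self, config: dict) -> None:
        self._config = config
        self._workspace_size: int | None = None

    @property
    def workspace_size(self) -> int:
        if self._workspace_size is None:
            self._workspace_size = expensive_workspace_size(self._config)
        return self._workspace_size


class EagerWorkspace:
    """Pattern restored by the revert: compute once at construction."""

    def __init__(self, config: dict) -> None:
        self.workspace_size = expensive_workspace_size(config)


if __name__ == "__main__":
    cfg = {"max_num_batched_tokens": 4096}
    assert LazyWorkspace(cfg).workspace_size == EagerWorkspace(cfg).workspace_size == 8192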