[Bugfix] Fix Basic Models Test (#34818)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2026-02-19 17:49:07 -05:00
committed by GitHub
parent 4fb8beefaa
commit 662205d34e
14 changed files with 175 additions and 221 deletions

View File

@@ -407,17 +407,24 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
# Attributes for forward_impl method
self.chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
get_current_vllm_config()
)
)
self._vllm_config = get_current_vllm_config()
self._chunked_prefill_workspace_size: int | None = None
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)
@property
def chunked_prefill_workspace_size(self) -> int:
    """Return the chunked-prefill workspace size, computing it on first use.

    The size is not resolved up front. On the first read it is obtained from
    ``MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size``
    using ``self._vllm_config`` and stored in
    ``self._chunked_prefill_workspace_size``; subsequent reads return the
    stored value without calling the builder again.
    """
    size = self._chunked_prefill_workspace_size
    if size is not None:
        # Already resolved on an earlier access.
        return size

    # First access: ask the metadata builder and remember the answer.
    resolve = MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size
    size = resolve(self._vllm_config)
    self._chunked_prefill_workspace_size = size
    return size
def forward(
self,
q: torch.Tensor,