[CI] Fix SPLADE pooler test broken by #38139 (#38495)

Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
haosdent
2026-03-30 15:48:33 +08:00
committed by GitHub
parent 85c0950b1f
commit a08b7733fd

View File

@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
import pytest
import torch
import torch.nn as nn
@@ -11,6 +9,8 @@ from vllm.model_executor.models.bert import (
BertMLMHead,
SPLADESparsePooler,
)
from vllm.pooling_params import PoolingParams
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
# ---------------------------------------------------------------------
# Functional test: SPLADE formula correctness (no HF download needed)
@@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
],
dtype=torch.long,
)
meta = types.SimpleNamespace(
prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids
meta = PoolingMetadata(
prompt_lens=prompt_lens_tenser,
prompt_token_ids=token_ids,
prompt_token_ids_cpu=token_ids,
pooling_params=[PoolingParams(task="embed")] * B,
pooling_states=[PoolingStates() for _ in range(B)],
)
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)