Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -11,6 +9,8 @@ from vllm.model_executor.models.bert import (
|
||||
BertMLMHead,
|
||||
SPLADESparsePooler,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Functional test: SPLADE formula correctness (no HF download needed)
|
||||
@@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
|
||||
],
|
||||
dtype=torch.long,
|
||||
)
|
||||
meta = types.SimpleNamespace(
|
||||
prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids
|
||||
meta = PoolingMetadata(
|
||||
prompt_lens=prompt_lens_tenser,
|
||||
prompt_token_ids=token_ids,
|
||||
prompt_token_ids_cpu=token_ids,
|
||||
pooling_params=[PoolingParams(task="embed")] * B,
|
||||
pooling_states=[PoolingStates() for _ in range(B)],
|
||||
)
|
||||
|
||||
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
|
||||
|
||||
Reference in New Issue
Block a user