Refactor sliding window configuration to Transformers best practice (#21927)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-08-10 04:50:48 +01:00
committed by GitHub
parent 2a84fb422f
commit c49848396d
16 changed files with 123 additions and 231 deletions

View File

@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
assert model_config.max_model_len == expected
def test_get_sliding_window():
    """Check the value reported by ``ModelConfig.get_sliding_window()``.

    Two checkpoints are exercised:

    * ``Qwen/Qwen1.5-7B`` -- the configured window must only be reported
      when ``use_sliding_window`` is enabled on the HF config.
    * ``mistralai/Mistral-7B-v0.1`` -- the reported value must mirror
      ``hf_config.sliding_window``, including when it is ``None``.
    """
    window_size = 4096

    # Qwen1.5/Qwen2: a window is configured, but get_sliding_window() should
    # be None for as long as use_sliding_window is False.
    qwen_config = ModelConfig("Qwen/Qwen1.5-7B")
    qwen_hf = qwen_config.hf_config
    qwen_hf.use_sliding_window = False
    qwen_hf.sliding_window = window_size
    assert qwen_config.get_sliding_window() is None

    # Turning the flag on exposes the configured window size.
    qwen_hf.use_sliding_window = True
    assert qwen_config.get_sliding_window() == window_size

    # Mistral: no window configured -> nothing reported.
    mistral_config = ModelConfig("mistralai/Mistral-7B-v0.1")
    mistral_hf = mistral_config.hf_config
    mistral_hf.sliding_window = None
    assert mistral_config.get_sliding_window() is None

    # Mistral: once a window is configured it is reported unchanged.
    mistral_hf.sliding_window = window_size
    assert mistral_config.get_sliding_window() == window_size
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():