Refactor sliding window configuration to Transformers best practice (#21927)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -200,28 +200,6 @@ def test_disable_sliding_window(model_id_expected):
|
||||
assert model_config.max_model_len == expected
|
||||
|
||||
|
||||
def test_get_sliding_window():
    """Verify ``ModelConfig.get_sliding_window()`` against the HF config.

    Two model families are exercised:

    * Qwen1.5/Qwen2: the window is reported only while
      ``hf_config.use_sliding_window`` is True; otherwise the result is None.
    * Mistral: the result mirrors ``hf_config.sliding_window`` (None or the
      configured value).
    """
    window = 4096

    # Qwen1.5/Qwen2: a configured window must be ignored while
    # use_sliding_window is switched off.
    qwen_config = ModelConfig("Qwen/Qwen1.5-7B")
    qwen_config.hf_config.sliding_window = window
    qwen_config.hf_config.use_sliding_window = False
    assert qwen_config.get_sliding_window() is None

    # Turning the flag on exposes the configured window.
    qwen_config.hf_config.use_sliding_window = True
    assert qwen_config.get_sliding_window() == window

    # Mistral: no window configured means no window reported.
    mistral_config = ModelConfig("mistralai/Mistral-7B-v0.1")
    mistral_config.hf_config.sliding_window = None
    assert mistral_config.get_sliding_window() is None

    # A configured window is reported unchanged.
    mistral_config.hf_config.sliding_window = window
    assert mistral_config.get_sliding_window() == window
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config():
|
||||
|
||||
Reference in New Issue
Block a user