[core] gemma2 full context length support (#10584)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-22 20:13:54 -08:00
committed by GitHub
parent 978b39744b
commit 4aba6e3d1a
4 changed files with 54 additions and 23 deletions

View File

@@ -233,15 +233,26 @@ class ModelConfig:
(self.hf_text_config.model_type in ["gemma2"]))
if (not self.disable_sliding_window and has_interleaved_attention):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)
print_warning_once(
f"{self.hf_text_config.model_type} has interleaved attention, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
print_warning_once(
f"{self.hf_text_config.model_type} has interleaved "
"attention, which is currently not supported by the "
"XFORMERS backend. Disabling sliding window and capping "
"the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window
delattr(self.hf_text_config, "sliding_window")
sliding_window = None
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,