[TPU] Temporary fix for VMEM OOM with long model length by reducing page size (#20278)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
@@ -86,6 +86,12 @@ class PallasAttentionBackend(AttentionBackend):
|
|||||||
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
|
# spill less likely. Meanwhile we make sure the page size is in [16, 256].
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_page_size(vllm_config: VllmConfig) -> int:
|
def get_page_size(vllm_config: VllmConfig) -> int:
|
||||||
|
# TODO: This is a temporary fix for vmem OOM.
|
||||||
|
# For long model length, we use a page size of 16 to avoid too much
|
||||||
|
# VMEM spill. A more robust solution should be implemented to
|
||||||
|
# handle VREG spills.
|
||||||
|
if vllm_config.model_config.max_model_len > 8192:
|
||||||
|
return 16
|
||||||
page_size = next_power_of_2(
|
page_size = next_power_of_2(
|
||||||
vllm_config.model_config.max_model_len) // 16
|
vllm_config.model_config.max_model_len) // 16
|
||||||
if page_size <= 16:
|
if page_size <= 16:
|
||||||
|
|||||||
Reference in New Issue
Block a user