[v1] fix compilation cache (#11598)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author:    youkaichao
Date:      2024-12-30 12:24:12 +08:00
Committer: GitHub
Parent:    0aa38d16f5
Commit:    3682e33f9f
4 changed files with 69 additions and 14 deletions

@@ -9,8 +9,8 @@ from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
-                    Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
-                    Union)
+                    Final, List, Literal, Mapping, Optional, Protocol, Set,
+                    Tuple, Type, Union)

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -75,6 +75,12 @@ HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
                                              PretrainedConfig]]


+class SupportsHash(Protocol):
+
+    def compute_hash(self) -> str:
+        ...
+
+
 class ModelConfig:
     """Configuration for the model.
@@ -2969,6 +2975,10 @@ class VllmConfig:
                                               init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    # some opaque config, only used to provide additional information
+    # for the hash computation, mainly used for testing and debugging.
+    additional_config: SupportsHash = field(default=None,
+                                            init=True)  # type: ignore

     instance_id: str = ""

     def compute_hash(self) -> str:
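
A sketch of how the new field might be wired up, assuming a bare
VllmConfig() is constructible with its defaults (DebugHashConfig is the
hypothetical class sketched above):

    from vllm.config import VllmConfig

    config = VllmConfig()
    config.additional_config = DebugHashConfig("cache-debug-run")

    # additional_config now contributes one factor to compute_hash(), so two
    # otherwise-identical configs with different tags produce distinct
    # compilation cache keys.
    print(config.compute_hash())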
@@ -3000,33 +3010,62 @@ class VllmConfig:
         vllm_factors.append(__version__)
         if self.model_config:
             vllm_factors.append(self.model_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.cache_config:
             vllm_factors.append(self.cache_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.parallel_config:
             vllm_factors.append(self.parallel_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.scheduler_config:
             vllm_factors.append(self.scheduler_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.device_config:
             vllm_factors.append(self.device_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.load_config:
             vllm_factors.append(self.load_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.speculative_config:
             vllm_factors.append(self.speculative_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.decoding_config:
             vllm_factors.append(self.decoding_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.observability_config:
             vllm_factors.append(self.observability_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.prompt_adapter_config:
             vllm_factors.append(self.prompt_adapter_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.quant_config:
             pass  # should be captured by model_config.quantization
         if self.compilation_config:
             vllm_factors.append(self.compilation_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         if self.kv_transfer_config:
             vllm_factors.append(self.kv_transfer_config.compute_hash())
+        else:
+            vllm_factors.append("None")
+        if self.additional_config:
+            vllm_factors.append(self.additional_config.compute_hash())
+        else:
+            vllm_factors.append("None")
         factors.append(vllm_factors)

         hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
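
The new else branches matter because the hash is taken over str(factors), a
purely positional serialization. If unset configs were simply skipped, the
factor list would not record which config produced which entry, so two
different combinations of set/unset configs could serialize identically and
share a compilation cache key. A self-contained sketch of the distinction
(illustrative values, not vLLM output):

    import hashlib

    def digest(factors: list) -> str:
        # Same scheme as above: stringify the list, md5, truncate.
        return hashlib.md5(str(factors).encode()).hexdigest()[:10]

    # Skipping unset configs loses positional information: these two
    # distinct setups collide whenever their sub-hashes coincide.
    skipped_a = ["aaaa"]  # e.g. only lora_config set
    skipped_b = ["aaaa"]  # e.g. only speculative_config set, same sub-hash
    assert digest(skipped_a) == digest(skipped_b)  # collision

    # With an explicit "None" placeholder per slot, position encodes which
    # config is set, so the keys differ.
    padded_a = ["aaaa", "None"]
    padded_b = ["None", "aaaa"]
    assert digest(padded_a) != digest(padded_b)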