Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)
This commit is contained in:
@@ -217,6 +217,39 @@ def get_quant_config(model_config: ModelConfig,
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def get_sparse_attention_config(
    model_config: ModelConfig,
    load_config: LoadConfig,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Load the optional sparse attention config that ships with a model.

    Args:
        model_config: Supplies ``model`` (a local directory or a remote
            model identifier) and ``revision`` (used for the download).
        load_config: Supplies ``download_dir``, which is passed both to
            ``get_lock`` and as the download ``cache_dir``.
        sparse_attention_config_filename: Name of the JSON file to look
            for inside the model folder.

    Returns:
        The parsed JSON content of the config file, or an empty dict when
        the model folder does not contain such a file.
    """
    model_name_or_path = model_config.model

    if os.path.isdir(model_name_or_path):
        # The model already lives on disk; read straight from it.
        hf_folder = model_name_or_path
    else:
        # Download the config files (every ``*.json`` of the snapshot)
        # while holding the lock obtained for this model/download dir.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )

    config_file = os.path.join(hf_folder, sparse_attention_config_filename)
    if not os.path.exists(config_file):
        # The sparse attention config is optional.
        return {}

    # Load the sparse attention config. JSON text is UTF-8 by
    # specification, so decode it explicitly instead of relying on the
    # platform's locale encoding.
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)
    logger.info("Loaded sparse attention config from %s", config_file)

    return config
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
|
||||
Reference in New Issue
Block a user