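"""Checkpoint utilities for ``deepextractor.utils.checkpoints``.

Helpers to save and restore model/optimizer state, and to load pretrained
weights from a local directory or from the Hugging Face Hub.
"""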

import os
import logging

import torch

# Module-level logger for status and error messages emitted by this module.
logger = logging.getLogger(__name__)

# Hugging Face Hub repo that hosts pretrained weights.
HF_REPO_ID = "tomdooney/deepextractor"

# Named checkpoint constants — use these instead of bare filename strings.
CHECKPOINT_BILBY = "checkpoint_best_bilby_noise_base.pth.tar"  # trained on simulated bilby noise
CHECKPOINT_REAL = "checkpoint_best_real_noise_base.pth.tar"  # fine-tuned on real O3 LIGO data
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Persist a checkpoint object with :func:`torch.save`.

    A short progress banner is printed to stdout before serialization.

    Args:
        state: Object to serialize — typically a dict of model/optimizer
            state (presumably with ``"state_dict"`` / ``"optimizer"`` keys so
            it can be read back by this module's loaders; verify at call sites).
        filename: Destination handed to :func:`torch.save` (a path, or a
            writable file-like object). Defaults to ``"my_checkpoint.pth.tar"``.
    """
    banner = "=> Saving checkpoint"
    print(banner)
    torch.save(obj=state, f=filename)
def load_checkpoint(checkpoint, model):
    """Copy model weights out of an already-loaded checkpoint mapping.

    Args:
        checkpoint: Mapping that holds the model weights under the
            ``"state_dict"`` key (e.g. the object returned by ``torch.load``).
        model: Object providing ``load_state_dict`` — typically a
            ``torch.nn.Module``; it is updated in place.

    Raises:
        KeyError: If a ``KeyError`` occurs while fetching ``"state_dict"`` or
            applying it. A diagnostic line is printed before the error is
            re-raised unchanged.
    """
    print("=> Loading checkpoint")
    try:
        weights = checkpoint["state_dict"]
        model.load_state_dict(weights)
    except KeyError:
        # Explain the failure to the user, then propagate the original error.
        print("=> Failed to load checkpoint: state_dict not found")
        raise
def load_optimizer(checkpoint, optimizer):
    """Restore optimizer state from an already-loaded checkpoint mapping.

    Args:
        checkpoint: Mapping that holds the optimizer state under the
            ``"optimizer"`` key (e.g. the object returned by ``torch.load``).
        optimizer: Object providing ``load_state_dict`` — typically a
            ``torch.optim.Optimizer``; it is updated in place.

    Raises:
        KeyError: If a ``KeyError`` occurs while fetching ``"optimizer"`` or
            applying it. A diagnostic line is printed before the error is
            re-raised unchanged.
    """
    print("=> Loading optimizer")
    try:
        opt_state = checkpoint["optimizer"]
        optimizer.load_state_dict(opt_state)
    except KeyError:
        # Explain the failure to the user, then propagate the original error.
        print("=> Failed to load checkpoint: optimizer state not found")
        raise
def _resolve_checkpoint(model_name, checkpoint_dir, filename=CHECKPOINT_BILBY):
    """Return the local path to a checkpoint file.

    Checks ``<checkpoint_dir>/<model_name>/<filename>`` first. If
    ``checkpoint_dir`` is ``None`` or that file does not exist, downloads
    ``<model_name>/<filename>`` from the Hugging Face Hub repo ``HF_REPO_ID``.

    Args:
        model_name: Subdirectory (both locally and in the Hub repo) that
            contains the checkpoint file.
        checkpoint_dir: Local root directory to search first, or ``None``
            to skip the local lookup entirely.
        filename: Checkpoint file name inside the model subdirectory.

    Returns:
        str: A path on the local filesystem — either the local file found
        under ``checkpoint_dir`` or the path returned by ``hf_hub_download``.

    Raises:
        ImportError: If a download is required but ``huggingface_hub`` is
            not installed.
    """
    if checkpoint_dir is not None:
        local_path = os.path.join(checkpoint_dir, model_name, filename)
        if os.path.isfile(local_path):
            return local_path
        # Make the fallback visible: otherwise a wrong checkpoint_dir silently
        # turns into a network download with no indication why.
        logger.info(
            "Checkpoint not found at %s; falling back to Hugging Face Hub", local_path
        )

    # Imported lazily so huggingface_hub is only required when downloading.
    try:
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required to download pretrained weights. "
            "Install it with: pip install huggingface_hub"
        ) from e

    # Hub paths always use forward slashes, regardless of the local OS separator.
    hf_path = f"{model_name}/{filename}"
    # Lazy %-style args: the message is only formatted when INFO is enabled.
    logger.info("Downloading %s from %s ...", hf_path, HF_REPO_ID)
    return hf_hub_download(repo_id=HF_REPO_ID, filename=hf_path)
def load_torch_model(model_name, model_dict, checkpoint_dir=None, device="cpu",
                     checkpoint_filename=CHECKPOINT_BILBY):
    """Load a pretrained PyTorch model.

    Weights are resolved in this order:

    1. ``<checkpoint_dir>/<model_name>/<checkpoint_filename>`` (local)
    2. Hugging Face Hub (``tomdooney/deepextractor``) — downloaded and
       cached automatically.

    Use the module-level constants ``CHECKPOINT_BILBY`` and ``CHECKPOINT_REAL``
    for ``checkpoint_filename`` to select the correct variant.

    Args:
        model_name (str): Name of the model (must be a key in ``model_dict``).
        model_dict (dict): Mapping of model names to instantiated model objects.
        checkpoint_dir (str | None): Local directory to search first. Pass
            ``None`` to skip local lookup and always use Hugging Face Hub.
        device (str | torch.device): Device to load the model onto.
        checkpoint_filename (str): Checkpoint file name inside the model
            subdirectory.

    Returns:
        torch.nn.Module: The model with loaded weights in eval mode, or
        ``None`` on failure. Failures are not raised; they are logged at
        ERROR level together with their traceback.
    """
    try:
        # Note: the model object from model_dict is moved to ``device`` in place.
        model = model_dict[model_name].to(device)
        checkpoint_path = _resolve_checkpoint(model_name, checkpoint_dir, checkpoint_filename)
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, i.e. run code embedded in the file. Only load
        # checkpoints from sources you trust (including the Hub repo above).
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["state_dict"])
        model.eval()
    except Exception as e:
        # Public contract is "return None on failure", so we do not re-raise;
        # logger.exception keeps the traceback, which str(e) alone loses
        # (e.g. a bare KeyError message is just the missing key).
        logger.exception("Error loading model checkpoint for %s: %s", model_name, e)
        return None
    logger.info("Successfully loaded model: %s", model_name)
    return model